1 // cache.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // An Fst implementation that caches FST elements of a delayed
20 // computation.
21 
22 #ifndef FST_LIB_CACHE_H__
23 #define FST_LIB_CACHE_H__
24 
25 #include <vector>
26 using std::vector;
27 #include <list>
28 
29 #include <fst/vector-fst.h>
30 
31 
32 DECLARE_bool(fst_default_cache_gc);
33 DECLARE_int64(fst_default_cache_gc_limit);
34 
35 namespace fst {
36 
37 struct CacheOptions {
38   bool gc;          // enable GC
39   size_t gc_limit;  // # of bytes allowed before GC
40 
CacheOptionsCacheOptions41   CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
CacheOptionsCacheOptions42   CacheOptions()
43       : gc(FLAGS_fst_default_cache_gc),
44         gc_limit(FLAGS_fst_default_cache_gc_limit) {}
45 };
46 
47 // A CacheStateAllocator allocates and frees CacheStates
48 // template <class S>
49 // struct CacheStateAllocator {
50 //   S *Allocate(StateId s);
51 //   void Free(S *state, StateId s);
52 // };
53 //
54 
// Default state allocator; it can be replaced by any class offering the
// same Allocate/Free interface. It retains at most one freed state (the
// most recently freed one) so that the next allocation can reuse it.
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : mru_(NULL) {}

  // Destroys the retained state, if any.
  ~DefaultCacheStateAllocator() { delete mru_; }

  // Returns a state for ID 's' (the ID itself is not used here): the
  // retained state, after calling Reset() on it, when one is available,
  // and a newly constructed state otherwise.
  S *Allocate(StateId s) {
    if (!mru_)
      return new S();
    S *recycled = mru_;
    mru_ = NULL;
    recycled->Reset();
    return recycled;
  }

  // Retains 'state' for reuse by the next Allocate() call and destroys
  // whichever state was retained before it.
  void Free(S *state, StateId s) {
    delete mru_;  // deleting a null pointer is a no-op
    mru_ = state;
  }

 private:
  S *mru_;  // most recently freed state, or NULL when none is retained
};
87 
88 // VectorState but additionally has a flags data member (see
89 // CacheState below). This class is used to cache FST elements with
90 // the flags used to indicate what has been cached. Use HasStart()
91 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
92 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
93 // you must set the final weight even if the state is non-final to
94 // mark it as cached. If the 'gc' option is 'false', cached items have
95 // the extent of the FST - minimizing computation. If the 'gc' option
96 // is 'true', garbage collection of states (not in use in an arc
97 // iterator and not 'protected') is performed, in a rough
98 // approximation of LRU order, when 'gc_limit' bytes is reached -
99 // controlling memory use. When 'gc_limit' is 0, special optimizations
100 // apply - minimizing memory use.
101 
102 template <class S, class C = DefaultCacheStateAllocator<S> >
103 class CacheBaseImpl : public VectorFstBaseImpl<S> {
104  public:
105   typedef S State;
106   typedef C Allocator;
107   typedef typename State::Arc Arc;
108   typedef typename Arc::Weight Weight;
109   typedef typename Arc::StateId StateId;
110 
111   using FstImpl<Arc>::Type;
112   using FstImpl<Arc>::Properties;
113   using FstImpl<Arc>::SetProperties;
114   using VectorFstBaseImpl<State>::NumStates;
115   using VectorFstBaseImpl<State>::Start;
116   using VectorFstBaseImpl<State>::AddState;
117   using VectorFstBaseImpl<State>::SetState;
118   using VectorFstBaseImpl<State>::ReserveStates;
119 
120   explicit CacheBaseImpl(C *allocator = 0)
cache_start_(false)121       : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
122         cache_first_state_id_(kNoStateId), cache_first_state_(0),
123         cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
124         cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
125                      FLAGS_fst_default_cache_gc_limit == 0 ?
126                      FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
127         protect_(false) {
128     allocator_ = allocator ? allocator : new C();
129   }
130 
131   explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
cache_start_(false)132       : cache_start_(false), nknown_states_(0),
133         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
134         cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
135         cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
136                      opts.gc_limit : kMinCacheLimit),
137         protect_(false) {
138     allocator_ = allocator ? allocator : new C();
139   }
140 
141   // Preserve gc parameters. If preserve_cache true, also preserves
142   // cache data.
143   CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
144       : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
145         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
146         cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
147         cache_limit_(impl.cache_limit_),
148         protect_(impl.protect_) {
149     allocator_ = new C();
150     if (preserve_cache) {
151       cache_start_ = impl.cache_start_;
152       nknown_states_ = impl.nknown_states_;
153       expanded_states_ = impl.expanded_states_;
154       min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
155       if (impl.cache_first_state_id_ != kNoStateId) {
156         cache_first_state_id_ = impl.cache_first_state_id_;
157         cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
158         *cache_first_state_ = *impl.cache_first_state_;
159       }
160       cache_states_ = impl.cache_states_;
161       cache_size_ = impl.cache_size_;
162       ReserveStates(impl.NumStates());
163       for (StateId s = 0; s < impl.NumStates(); ++s) {
164         const S *state =
165             static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
166         if (state) {
167           S *copied_state = allocator_->Allocate(s);
168           *copied_state = *state;
169           AddState(copied_state);
170         } else {
171           AddState(0);
172         }
173       }
174       VectorFstBaseImpl<S>::SetStart(impl.Start());
175     }
176   }
177 
~CacheBaseImpl()178   ~CacheBaseImpl() {
179     allocator_->Free(cache_first_state_, cache_first_state_id_);
180     delete allocator_;
181   }
182 
183   // Gets a state from its ID; state must exist.
GetState(StateId s)184   const S *GetState(StateId s) const {
185     if (s == cache_first_state_id_)
186       return cache_first_state_;
187     else
188       return VectorFstBaseImpl<S>::GetState(s);
189   }
190 
191   // Gets a state from its ID; state must exist.
GetState(StateId s)192   S *GetState(StateId s) {
193     if (s == cache_first_state_id_)
194       return cache_first_state_;
195     else
196       return VectorFstBaseImpl<S>::GetState(s);
197   }
198 
199   // Gets a state from its ID; return 0 if it doesn't exist.
CheckState(StateId s)200   const S *CheckState(StateId s) const {
201     if (s == cache_first_state_id_)
202       return cache_first_state_;
203     else if (s < NumStates())
204       return VectorFstBaseImpl<S>::GetState(s);
205     else
206       return 0;
207   }
208 
209   // Gets a state from its ID; add it if necessary.
210   S *ExtendState(StateId s);
211 
SetStart(StateId s)212   void SetStart(StateId s) {
213     VectorFstBaseImpl<S>::SetStart(s);
214     cache_start_ = true;
215     if (s >= nknown_states_)
216       nknown_states_ = s + 1;
217   }
218 
SetFinal(StateId s,Weight w)219   void SetFinal(StateId s, Weight w) {
220     S *state = ExtendState(s);
221     state->final = w;
222     state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
223   }
224 
225   // AddArc adds a single arc to state s and does incremental cache
226   // book-keeping.  For efficiency, prefer PushArc and SetArcs below
227   // when possible.
AddArc(StateId s,const Arc & arc)228   void AddArc(StateId s, const Arc &arc) {
229     S *state = ExtendState(s);
230     state->arcs.push_back(arc);
231     if (arc.ilabel == 0) {
232       ++state->niepsilons;
233     }
234     if (arc.olabel == 0) {
235       ++state->noepsilons;
236     }
237     const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
238     SetProperties(AddArcProperties(Properties(), s, arc, parc));
239     state->flags |= kCacheModified;
240     if (cache_gc_ && s != cache_first_state_id_ &&
241         !(state->flags & kCacheProtect)) {
242       cache_size_ += sizeof(Arc);
243       if (cache_size_ > cache_limit_)
244         GC(s, false);
245     }
246   }
247 
248   // Adds a single arc to state s but delays cache book-keeping.
249   // SetArcs must be called when all PushArc calls at a state are
250   // complete.  Do not mix with calls to AddArc.
PushArc(StateId s,const Arc & arc)251   void PushArc(StateId s, const Arc &arc) {
252     S *state = ExtendState(s);
253     state->arcs.push_back(arc);
254   }
255 
256   // Marks arcs of state s as cached and does cache book-keeping after all
257   // calls to PushArc have been completed.  Do not mix with calls to AddArc.
SetArcs(StateId s)258   void SetArcs(StateId s) {
259     S *state = ExtendState(s);
260     vector<Arc> &arcs = state->arcs;
261     state->niepsilons = state->noepsilons = 0;
262     for (size_t a = 0; a < arcs.size(); ++a) {
263       const Arc &arc = arcs[a];
264       if (arc.nextstate >= nknown_states_)
265         nknown_states_ = arc.nextstate + 1;
266       if (arc.ilabel == 0)
267         ++state->niepsilons;
268       if (arc.olabel == 0)
269         ++state->noepsilons;
270     }
271     ExpandedState(s);
272     state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
273     if (cache_gc_ && s != cache_first_state_id_ &&
274         !(state->flags & kCacheProtect)) {
275       cache_size_ += arcs.capacity() * sizeof(Arc);
276       if (cache_size_ > cache_limit_)
277         GC(s, false);
278     }
279   };
280 
ReserveArcs(StateId s,size_t n)281   void ReserveArcs(StateId s, size_t n) {
282     S *state = ExtendState(s);
283     state->arcs.reserve(n);
284   }
285 
DeleteArcs(StateId s,size_t n)286   void DeleteArcs(StateId s, size_t n) {
287     S *state = ExtendState(s);
288     const vector<Arc> &arcs = state->arcs;
289     for (size_t i = 0; i < n; ++i) {
290       size_t j = arcs.size() - i - 1;
291       if (arcs[j].ilabel == 0)
292         --state->niepsilons;
293       if (arcs[j].olabel == 0)
294         --state->noepsilons;
295     }
296 
297     state->arcs.resize(arcs.size() - n);
298     SetProperties(DeleteArcsProperties(Properties()));
299     state->flags |= kCacheModified;
300     if (cache_gc_ && s != cache_first_state_id_ &&
301         !(state->flags & kCacheProtect)) {
302       cache_size_ -= n * sizeof(Arc);
303     }
304   }
305 
DeleteArcs(StateId s)306   void DeleteArcs(StateId s) {
307     S *state = ExtendState(s);
308     size_t n = state->arcs.size();
309     state->niepsilons = 0;
310     state->noepsilons = 0;
311     state->arcs.clear();
312     SetProperties(DeleteArcsProperties(Properties()));
313     state->flags |= kCacheModified;
314     if (cache_gc_ && s != cache_first_state_id_ &&
315         !(state->flags & kCacheProtect)) {
316       cache_size_ -= n * sizeof(Arc);
317     }
318   }
319 
DeleteStates(const vector<StateId> & dstates)320   void DeleteStates(const vector<StateId> &dstates) {
321     size_t old_num_states = NumStates();
322     vector<StateId> newid(old_num_states, 0);
323     for (size_t i = 0; i < dstates.size(); ++i)
324       newid[dstates[i]] = kNoStateId;
325     StateId nstates = 0;
326     for (StateId s = 0; s < old_num_states; ++s) {
327       if (newid[s] != kNoStateId) {
328         newid[s] = nstates;
329         ++nstates;
330       }
331     }
332     // just for states_.resize(), does unnecessary walk.
333     VectorFstBaseImpl<S>::DeleteStates(dstates);
334     SetProperties(DeleteStatesProperties(Properties()));
335     // Update list of cached states.
336     typename list<StateId>::iterator siter = cache_states_.begin();
337     while (siter != cache_states_.end()) {
338       if (newid[*siter] != kNoStateId) {
339         *siter = newid[*siter];
340         ++siter;
341       } else {
342         cache_states_.erase(siter++);
343       }
344     }
345   }
346 
DeleteStates()347   void DeleteStates() {
348     cache_states_.clear();
349     allocator_->Free(cache_first_state_, cache_first_state_id_);
350     for (int s = 0; s < NumStates(); ++s) {
351       allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
352       SetState(s, 0);
353     }
354     nknown_states_ = 0;
355     min_unexpanded_state_id_ = 0;
356     cache_first_state_id_ = kNoStateId;
357     cache_first_state_ = 0;
358     cache_size_ = 0;
359     cache_start_ = false;
360     VectorFstBaseImpl<State>::DeleteStates();
361     SetProperties(DeleteAllStatesProperties(Properties(),
362                                             kExpanded | kMutable));
363   }
364 
365   // Is the start state cached?
HasStart()366   bool HasStart() const {
367     if (!cache_start_ && Properties(kError))
368       cache_start_ = true;
369     return cache_start_;
370   }
371 
372   // Is the final weight of state s cached?
HasFinal(StateId s)373   bool HasFinal(StateId s) const {
374     const S *state = CheckState(s);
375     if (state && state->flags & kCacheFinal) {
376       state->flags |= kCacheRecent;
377       return true;
378     } else {
379       return false;
380     }
381   }
382 
383   // Are arcs of state s cached?
HasArcs(StateId s)384   bool HasArcs(StateId s) const {
385     const S *state = CheckState(s);
386     if (state && state->flags & kCacheArcs) {
387       state->flags |= kCacheRecent;
388       return true;
389     } else {
390       return false;
391     }
392   }
393 
Final(StateId s)394   Weight Final(StateId s) const {
395     const S *state = GetState(s);
396     return state->final;
397   }
398 
NumArcs(StateId s)399   size_t NumArcs(StateId s) const {
400     const S *state = GetState(s);
401     return state->arcs.size();
402   }
403 
NumInputEpsilons(StateId s)404   size_t NumInputEpsilons(StateId s) const {
405     const S *state = GetState(s);
406     return state->niepsilons;
407   }
408 
NumOutputEpsilons(StateId s)409   size_t NumOutputEpsilons(StateId s) const {
410     const S *state = GetState(s);
411     return state->noepsilons;
412   }
413 
414   // Provides information needed for generic arc iterator.
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)415   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
416     const S *state = GetState(s);
417     data->base = 0;
418     data->narcs = state->arcs.size();
419     data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
420     data->ref_count = &(state->ref_count);
421     ++(*data->ref_count);
422   }
423 
424   // Number of known states.
NumKnownStates()425   StateId NumKnownStates() const { return nknown_states_; }
426 
427   // Update number of known states taking in account the existence of state s.
UpdateNumKnownStates(StateId s)428   void UpdateNumKnownStates(StateId s) {
429     if (s >= nknown_states_)
430       nknown_states_ = s + 1;
431   }
432 
433   // Find the mininum never-expanded state Id
MinUnexpandedState()434   StateId MinUnexpandedState() const {
435     while (min_unexpanded_state_id_ < expanded_states_.size() &&
436           expanded_states_[min_unexpanded_state_id_])
437       ++min_unexpanded_state_id_;
438     return min_unexpanded_state_id_;
439   }
440 
441   // Removes from cache_states_ and uncaches (not referenced-counted
442   // or protected) states that have not been accessed since the last
443   // GC until at most cache_fraction * cache_limit_ bytes are cached.
444   // If that fails to free enough, recurs uncaching recently visited
445   // states as well. If still unable to free enough memory, then
446   // widens cache_limit_ to fulfill condition.
447   void GC(StateId current, bool free_recent,  float cache_fraction = 0.666);
448 
449   // Setc/clears GC protection: if true, new states are protected
450   // from garbage collection.
GCProtect(bool on)451   void GCProtect(bool on) { protect_ = on; }
452 
ExpandedState(StateId s)453   void ExpandedState(StateId s) {
454     if (s < min_unexpanded_state_id_)
455       return;
456     while (expanded_states_.size() <= s)
457       expanded_states_.push_back(false);
458     expanded_states_[s] = true;
459   }
460 
GetAllocator()461   C *GetAllocator() const {
462     return allocator_;
463   }
464 
465   // Caching on/off switch, limit and size accessors.
GetCacheGc()466   bool GetCacheGc() const { return cache_gc_; }
GetCacheLimit()467   size_t GetCacheLimit() const { return cache_limit_; }
GetCacheSize()468   size_t GetCacheSize() const { return cache_size_; }
469 
470  private:
471   static const size_t kMinCacheLimit = 8096;   // Minimum (non-zero) cache limit
472 
473   static const uint32 kCacheFinal =    0x0001;  // Final weight has been cached
474   static const uint32 kCacheArcs =     0x0002;  // Arcs have been cached
475   static const uint32 kCacheRecent =   0x0004;  // Mark as visited since GC
476   static const uint32 kCacheProtect =  0x0008;  // Mark state as GC protected
477 
478  public:
479   static const uint32 kCacheModified = 0x0010;  // Mark state as modified
480   static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
481       | kCacheProtect | kCacheModified;
482 
483  private:
484   C *allocator_;                             // used to allocate new states
485   mutable bool cache_start_;                 // Is the start state cached?
486   StateId nknown_states_;                    // # of known states
487   vector<bool> expanded_states_;             // states that have been expanded
488   mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
489   StateId cache_first_state_id_;             // First cached state id
490   S *cache_first_state_;                     // First cached state
491   list<StateId> cache_states_;               // list of currently cached states
492   bool cache_gc_;                            // enable GC
493   size_t cache_size_;                        // # of bytes cached
494   size_t cache_limit_;                       // # of bytes allowed before GC
495   bool protect_;                             // Protect new states from GC
496 
497   void operator=(const CacheBaseImpl<S, C> &impl);    // disallow
498 };
499 
500 // Gets a state from its ID; add it if necessary.
501 template <class S, class C>
ExtendState(typename S::Arc::StateId s)502 S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
503   // If 'protect_' true and a new state, protects from garbage collection.
504   if (s == cache_first_state_id_) {
505     return cache_first_state_;                   // Return 1st cached state
506   } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
507     cache_first_state_id_ = s;                   // Remember 1st cached state
508     cache_first_state_ = allocator_->Allocate(s);
509     if (protect_) cache_first_state_->flags |= kCacheProtect;
510     return cache_first_state_;
511   } else if (cache_first_state_id_ != kNoStateId &&
512              cache_first_state_->ref_count == 0 &&
513              !(cache_first_state_->flags & kCacheProtect)) {
514     // With Default allocator, the Free and Allocate will reuse the same S*.
515     allocator_->Free(cache_first_state_, cache_first_state_id_);
516     cache_first_state_id_ = s;
517     cache_first_state_ = allocator_->Allocate(s);
518     if (protect_) cache_first_state_->flags |= kCacheProtect;
519     return cache_first_state_;                   // Return 1st cached state
520   } else {
521     while (NumStates() <= s)                     // Add state to main cache
522       AddState(0);
523     S *state = VectorFstBaseImpl<S>::GetState(s);
524     if (!state) {
525       state = allocator_->Allocate(s);
526       if (protect_) state->flags |= kCacheProtect;
527       SetState(s, state);
528       if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
529         while (NumStates() <= cache_first_state_id_)
530           AddState(0);
531         SetState(cache_first_state_id_, cache_first_state_);
532         if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
533           cache_states_.push_back(cache_first_state_id_);
534           cache_size_ += sizeof(S) +
535                          cache_first_state_->arcs.capacity() * sizeof(Arc);
536         }
537         cache_limit_ = kMinCacheLimit;
538         cache_first_state_id_ = kNoStateId;
539         cache_first_state_ = 0;
540       }
541       if (cache_gc_ && !protect_) {
542         cache_states_.push_back(s);
543         cache_size_ += sizeof(S);
544         if (cache_size_ > cache_limit_)
545           GC(s, false);
546       }
547     }
548     return state;
549   }
550 }
551 
552 // Removes from cache_states_ and uncaches (not referenced-counted or
553 // protected) states that have not been accessed since the last GC
554 // until at most cache_fraction * cache_limit_ bytes are cached.  If
555 // that fails to free enough, recurs uncaching recently visited states
556 // as well. If still unable to free enough memory, then widens cache_limit_
557 // to fulfill condition.
558 template <class S, class C>
GC(typename S::Arc::StateId current,bool free_recent,float cache_fraction)559 void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
560                              bool free_recent, float cache_fraction) {
561   if (!cache_gc_)
562     return;
563   VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
564           << "), free recently cached = " << free_recent
565           << ", cache size = " << cache_size_
566           << ", cache frac = " << cache_fraction
567           << ", cache limit = " << cache_limit_ << "\n";
568   typename list<StateId>::iterator siter = cache_states_.begin();
569 
570   size_t cache_target = cache_fraction * cache_limit_;
571   while (siter != cache_states_.end()) {
572     StateId s = *siter;
573     S* state = VectorFstBaseImpl<S>::GetState(s);
574     if (cache_size_ > cache_target && state->ref_count == 0 &&
575         (free_recent || !(state->flags & kCacheRecent)) && s != current) {
576       cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
577       allocator_->Free(state, s);
578       SetState(s, 0);
579       cache_states_.erase(siter++);
580     } else {
581       state->flags &= ~kCacheRecent;
582       ++siter;
583     }
584   }
585   if (!free_recent && cache_size_ > cache_target) {   // recurses on recent
586     GC(current, true);
587   } else if (cache_target > 0) {                      // widens cache limit
588     while (cache_size_ > cache_target) {
589       cache_limit_ *= 2;
590       cache_target *= 2;
591     }
592   } else if (cache_size_ > 0) {
593     FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
594   }
595   VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
596           << "), free recently cached = " << free_recent
597           << ", cache size = " << cache_size_
598           << ", cache frac = " << cache_fraction
599           << ", cache limit = " << cache_limit_ << "\n";
600 }
601 
602 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
603 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
604 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
605 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
606 template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;
607 
608 // Arcs implemented by an STL vector per state. Similar to VectorState
609 // but adds flags and ref count to keep track of what has been cached.
610 template <class A>
611 struct CacheState {
612   typedef A Arc;
613   typedef typename A::Weight Weight;
614   typedef typename A::StateId StateId;
615 
CacheStateCacheState616   CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
617 
ResetCacheState618   void Reset() {
619     flags = 0;
620     ref_count = 0;
621     arcs.resize(0);
622   }
623 
624   Weight final;              // Final weight
625   vector<A> arcs;            // Arcs represenation
626   size_t niepsilons;         // # of input epsilons
627   size_t noepsilons;         // # of output epsilons
628   mutable uint32 flags;
629   mutable int ref_count;
630 };
631 
632 // A CacheBaseImpl with a commonly used CacheState.
633 template <class A>
634 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
635  public:
636   typedef CacheState<A> State;
637 
CacheImpl()638   CacheImpl() {}
639 
CacheImpl(const CacheOptions & opts)640   explicit CacheImpl(const CacheOptions &opts)
641       : CacheBaseImpl< CacheState<A> >(opts) {}
642 
643   CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
644       : CacheBaseImpl<State>(impl, preserve_cache) {}
645 
646  private:
647   void operator=(const CacheImpl<State> &impl);    // disallow
648 };
649 
650 
651 // Use this to make a state iterator for a CacheBaseImpl-derived Fst,
652 // which must have type 'State' defined.  Note this iterator only
653 // returns those states reachable from the initial state, so consider
654 // implementing a class-specific one.
655 template <class F>
656 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
657  public:
658   typedef typename F::Arc Arc;
659   typedef typename Arc::StateId StateId;
660   typedef typename F::State State;
661   typedef CacheBaseImpl<State> Impl;
662 
CacheStateIterator(const F & fst,Impl * impl)663   CacheStateIterator(const F &fst, Impl *impl)
664       : fst_(fst), impl_(impl), s_(0) {
665         fst_.Start();  // force start state
666       }
667 
Done()668   bool Done() const {
669     if (s_ < impl_->NumKnownStates())
670       return false;
671     if (s_ < impl_->NumKnownStates())
672       return false;
673     for (StateId u = impl_->MinUnexpandedState();
674          u < impl_->NumKnownStates();
675          u = impl_->MinUnexpandedState()) {
676       // force state expansion
677       ArcIterator<F> aiter(fst_, u);
678       aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
679       for (; !aiter.Done(); aiter.Next())
680         impl_->UpdateNumKnownStates(aiter.Value().nextstate);
681       impl_->ExpandedState(u);
682       if (s_ < impl_->NumKnownStates())
683         return false;
684     }
685     return true;
686   }
687 
Value()688   StateId Value() const { return s_; }
689 
Next()690   void Next() { ++s_; }
691 
Reset()692   void Reset() { s_ = 0; }
693 
694  private:
695   // This allows base class virtual access to non-virtual derived-
696   // class members of the same name. It makes the derived class more
697   // efficient to use but unsafe to further derive.
Done_()698   virtual bool Done_() const { return Done(); }
Value_()699   virtual StateId Value_() const { return Value(); }
Next_()700   virtual void Next_() { Next(); }
Reset_()701   virtual void Reset_() { Reset(); }
702 
703   const F &fst_;
704   Impl *impl_;
705   StateId s_;
706 };
707 
708 
709 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst,
710 // which must have types 'Arc' and 'State' defined.
711 template <class F,
712           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
713 class CacheArcIterator {
714  public:
715   typedef typename F::Arc Arc;
716   typedef typename F::State State;
717   typedef typename Arc::StateId StateId;
718   typedef CacheBaseImpl<State, C> Impl;
719 
CacheArcIterator(Impl * impl,StateId s)720   CacheArcIterator(Impl *impl, StateId s) : i_(0) {
721     state_ = impl->ExtendState(s);
722     ++state_->ref_count;
723   }
724 
~CacheArcIterator()725   ~CacheArcIterator() { --state_->ref_count;  }
726 
Done()727   bool Done() const { return i_ >= state_->arcs.size(); }
728 
Value()729   const Arc& Value() const { return state_->arcs[i_]; }
730 
Next()731   void Next() { ++i_; }
732 
Position()733   size_t Position() const { return i_; }
734 
Reset()735   void Reset() { i_ = 0; }
736 
Seek(size_t a)737   void Seek(size_t a) { i_ = a; }
738 
Flags()739   uint32 Flags() const {
740     return kArcValueFlags;
741   }
742 
SetFlags(uint32 flags,uint32 mask)743   void SetFlags(uint32 flags, uint32 mask) {}
744 
745  private:
746   const State *state_;
747   size_t i_;
748 
749   DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
750 };
751 
752 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
753 // which must have types 'Arc' and 'State' defined.
754 template <class F,
755           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
756 class CacheMutableArcIterator
757     : public MutableArcIteratorBase<typename F::Arc> {
758  public:
759   typedef typename F::State State;
760   typedef typename F::Arc Arc;
761   typedef typename Arc::StateId StateId;
762   typedef typename Arc::Weight Weight;
763   typedef CacheBaseImpl<State, C> Impl;
764 
765   // You will need to call MutateCheck() in the constructor.
CacheMutableArcIterator(Impl * impl,StateId s)766   CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
767     state_ = impl_->ExtendState(s_);
768     ++state_->ref_count;
769   };
770 
~CacheMutableArcIterator()771   ~CacheMutableArcIterator() {
772     --state_->ref_count;
773   }
774 
Done()775   bool Done() const { return i_ >= state_->arcs.size(); }
776 
Value()777   const Arc& Value() const { return state_->arcs[i_]; }
778 
Next()779   void Next() { ++i_; }
780 
Position()781   size_t Position() const { return i_; }
782 
Reset()783   void Reset() { i_ = 0; }
784 
Seek(size_t a)785   void Seek(size_t a) { i_ = a; }
786 
SetValue(const Arc & arc)787   void SetValue(const Arc& arc) {
788     state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
789     uint64 properties = impl_->Properties();
790     Arc& oarc = state_->arcs[i_];
791     if (oarc.ilabel != oarc.olabel)
792       properties &= ~kNotAcceptor;
793     if (oarc.ilabel == 0) {
794       --state_->niepsilons;
795       properties &= ~kIEpsilons;
796       if (oarc.olabel == 0)
797         properties &= ~kEpsilons;
798     }
799     if (oarc.olabel == 0) {
800       --state_->noepsilons;
801       properties &= ~kOEpsilons;
802     }
803     if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
804       properties &= ~kWeighted;
805     oarc = arc;
806     if (arc.ilabel != arc.olabel) {
807       properties |= kNotAcceptor;
808       properties &= ~kAcceptor;
809     }
810     if (arc.ilabel == 0) {
811       ++state_->niepsilons;
812       properties |= kIEpsilons;
813       properties &= ~kNoIEpsilons;
814       if (arc.olabel == 0) {
815         properties |= kEpsilons;
816         properties &= ~kNoEpsilons;
817       }
818     }
819     if (arc.olabel == 0) {
820       ++state_->noepsilons;
821       properties |= kOEpsilons;
822       properties &= ~kNoOEpsilons;
823     }
824     if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
825       properties |= kWeighted;
826       properties &= ~kUnweighted;
827     }
828     properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
829         kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
830         kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
831     impl_->SetProperties(properties);
832   }
833 
Flags()834   uint32 Flags() const {
835     return kArcValueFlags;
836   }
837 
SetFlags(uint32 f,uint32 m)838   void SetFlags(uint32 f, uint32 m) {}
839 
840  private:
Done_()841   virtual bool Done_() const { return Done(); }
Value_()842   virtual const Arc& Value_() const { return Value(); }
Next_()843   virtual void Next_() { Next(); }
Position_()844   virtual size_t Position_() const { return Position(); }
Reset_()845   virtual void Reset_() { Reset(); }
Seek_(size_t a)846   virtual void Seek_(size_t a) { Seek(a); }
SetValue_(const Arc & a)847   virtual void SetValue_(const Arc &a) { SetValue(a); }
Flags_()848   uint32 Flags_() const { return Flags(); }
SetFlags_(uint32 f,uint32 m)849   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
850 
851   size_t i_;
852   StateId s_;
853   Impl *impl_;
854   State *state_;
855 
856   DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
857 };
858 
859 }  // namespace fst
860 
861 #endif  // FST_LIB_CACHE_H__
862