1 // cache.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // An Fst implementation that caches FST elements of a delayed
20 // computation.
21
22 #ifndef FST_LIB_CACHE_H__
23 #define FST_LIB_CACHE_H__
24
25 #include <vector>
26 using std::vector;
27 #include <list>
28
29 #include <fst/vector-fst.h>
30
31
32 DECLARE_bool(fst_default_cache_gc);
33 DECLARE_int64(fst_default_cache_gc_limit);
34
35 namespace fst {
36
37 struct CacheOptions {
38 bool gc; // enable GC
39 size_t gc_limit; // # of bytes allowed before GC
40
CacheOptionsCacheOptions41 CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
CacheOptionsCacheOptions42 CacheOptions()
43 : gc(FLAGS_fst_default_cache_gc),
44 gc_limit(FLAGS_fst_default_cache_gc_limit) {}
45 };
46
47 // A CacheStateAllocator allocates and frees CacheStates
48 // template <class S>
49 // struct CacheStateAllocator {
50 // S *Allocate(StateId s);
51 // void Free(S *state, StateId s);
52 // };
53 //
54
// A simple allocator class that can be overridden as needed. It maintains
// a single-entry cache: the most recently freed state is kept and handed
// back out (after calling its Reset()) by the next Allocate() call.
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : mru_(NULL) { }

  ~DefaultCacheStateAllocator() { delete mru_; }

  // Returns a state ready for use. The state ID 's' is not used by this
  // allocator.
  S *Allocate(StateId s) {
    if (!mru_)
      return new S();
    S *recycled = mru_;
    mru_ = NULL;
    recycled->Reset();
    return recycled;
  }

  // Takes ownership of 'state' and keeps it for reuse, deleting whichever
  // state was kept before. The state ID 's' is not used.
  void Free(S *state, StateId s) {
    delete mru_;  // deleting NULL is a no-op
    mru_ = state;
  }

 private:
  S *mru_;  // most recently freed state, or NULL
};
87
88 // VectorState but additionally has a flags data member (see
89 // CacheState below). This class is used to cache FST elements with
90 // the flags used to indicate what has been cached. Use HasStart()
91 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
92 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
93 // you must set the final weight even if the state is non-final to
94 // mark it as cached. If the 'gc' option is 'false', cached items have
95 // the extent of the FST - minimizing computation. If the 'gc' option
96 // is 'true', garbage collection of states (not in use in an arc
97 // iterator and not 'protected') is performed, in a rough
98 // approximation of LRU order, when 'gc_limit' bytes is reached -
99 // controlling memory use. When 'gc_limit' is 0, special optimizations
100 // apply - minimizing memory use.
101
102 template <class S, class C = DefaultCacheStateAllocator<S> >
103 class CacheBaseImpl : public VectorFstBaseImpl<S> {
104 public:
105 typedef S State;
106 typedef C Allocator;
107 typedef typename State::Arc Arc;
108 typedef typename Arc::Weight Weight;
109 typedef typename Arc::StateId StateId;
110
111 using FstImpl<Arc>::Type;
112 using FstImpl<Arc>::Properties;
113 using FstImpl<Arc>::SetProperties;
114 using VectorFstBaseImpl<State>::NumStates;
115 using VectorFstBaseImpl<State>::Start;
116 using VectorFstBaseImpl<State>::AddState;
117 using VectorFstBaseImpl<State>::SetState;
118 using VectorFstBaseImpl<State>::ReserveStates;
119
120 explicit CacheBaseImpl(C *allocator = 0)
cache_start_(false)121 : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
122 cache_first_state_id_(kNoStateId), cache_first_state_(0),
123 cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0),
124 cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
125 FLAGS_fst_default_cache_gc_limit == 0 ?
126 FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
127 protect_(false) {
128 allocator_ = allocator ? allocator : new C();
129 }
130
131 explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
cache_start_(false)132 : cache_start_(false), nknown_states_(0),
133 min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
134 cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
135 cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
136 opts.gc_limit : kMinCacheLimit),
137 protect_(false) {
138 allocator_ = allocator ? allocator : new C();
139 }
140
141 // Preserve gc parameters. If preserve_cache true, also preserves
142 // cache data.
143 CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
144 : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
145 min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
146 cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
147 cache_limit_(impl.cache_limit_),
148 protect_(impl.protect_) {
149 allocator_ = new C();
150 if (preserve_cache) {
151 cache_start_ = impl.cache_start_;
152 nknown_states_ = impl.nknown_states_;
153 expanded_states_ = impl.expanded_states_;
154 min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
155 if (impl.cache_first_state_id_ != kNoStateId) {
156 cache_first_state_id_ = impl.cache_first_state_id_;
157 cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
158 *cache_first_state_ = *impl.cache_first_state_;
159 }
160 cache_states_ = impl.cache_states_;
161 cache_size_ = impl.cache_size_;
162 ReserveStates(impl.NumStates());
163 for (StateId s = 0; s < impl.NumStates(); ++s) {
164 const S *state =
165 static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
166 if (state) {
167 S *copied_state = allocator_->Allocate(s);
168 *copied_state = *state;
169 AddState(copied_state);
170 } else {
171 AddState(0);
172 }
173 }
174 VectorFstBaseImpl<S>::SetStart(impl.Start());
175 }
176 }
177
~CacheBaseImpl()178 ~CacheBaseImpl() {
179 allocator_->Free(cache_first_state_, cache_first_state_id_);
180 delete allocator_;
181 }
182
183 // Gets a state from its ID; state must exist.
GetState(StateId s)184 const S *GetState(StateId s) const {
185 if (s == cache_first_state_id_)
186 return cache_first_state_;
187 else
188 return VectorFstBaseImpl<S>::GetState(s);
189 }
190
191 // Gets a state from its ID; state must exist.
GetState(StateId s)192 S *GetState(StateId s) {
193 if (s == cache_first_state_id_)
194 return cache_first_state_;
195 else
196 return VectorFstBaseImpl<S>::GetState(s);
197 }
198
199 // Gets a state from its ID; return 0 if it doesn't exist.
CheckState(StateId s)200 const S *CheckState(StateId s) const {
201 if (s == cache_first_state_id_)
202 return cache_first_state_;
203 else if (s < NumStates())
204 return VectorFstBaseImpl<S>::GetState(s);
205 else
206 return 0;
207 }
208
209 // Gets a state from its ID; add it if necessary.
210 S *ExtendState(StateId s);
211
SetStart(StateId s)212 void SetStart(StateId s) {
213 VectorFstBaseImpl<S>::SetStart(s);
214 cache_start_ = true;
215 if (s >= nknown_states_)
216 nknown_states_ = s + 1;
217 }
218
SetFinal(StateId s,Weight w)219 void SetFinal(StateId s, Weight w) {
220 S *state = ExtendState(s);
221 state->final = w;
222 state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
223 }
224
225 // AddArc adds a single arc to state s and does incremental cache
226 // book-keeping. For efficiency, prefer PushArc and SetArcs below
227 // when possible.
AddArc(StateId s,const Arc & arc)228 void AddArc(StateId s, const Arc &arc) {
229 S *state = ExtendState(s);
230 state->arcs.push_back(arc);
231 if (arc.ilabel == 0) {
232 ++state->niepsilons;
233 }
234 if (arc.olabel == 0) {
235 ++state->noepsilons;
236 }
237 const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
238 SetProperties(AddArcProperties(Properties(), s, arc, parc));
239 state->flags |= kCacheModified;
240 if (cache_gc_ && s != cache_first_state_id_ &&
241 !(state->flags & kCacheProtect)) {
242 cache_size_ += sizeof(Arc);
243 if (cache_size_ > cache_limit_)
244 GC(s, false);
245 }
246 }
247
248 // Adds a single arc to state s but delays cache book-keeping.
249 // SetArcs must be called when all PushArc calls at a state are
250 // complete. Do not mix with calls to AddArc.
PushArc(StateId s,const Arc & arc)251 void PushArc(StateId s, const Arc &arc) {
252 S *state = ExtendState(s);
253 state->arcs.push_back(arc);
254 }
255
256 // Marks arcs of state s as cached and does cache book-keeping after all
257 // calls to PushArc have been completed. Do not mix with calls to AddArc.
SetArcs(StateId s)258 void SetArcs(StateId s) {
259 S *state = ExtendState(s);
260 vector<Arc> &arcs = state->arcs;
261 state->niepsilons = state->noepsilons = 0;
262 for (size_t a = 0; a < arcs.size(); ++a) {
263 const Arc &arc = arcs[a];
264 if (arc.nextstate >= nknown_states_)
265 nknown_states_ = arc.nextstate + 1;
266 if (arc.ilabel == 0)
267 ++state->niepsilons;
268 if (arc.olabel == 0)
269 ++state->noepsilons;
270 }
271 ExpandedState(s);
272 state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
273 if (cache_gc_ && s != cache_first_state_id_ &&
274 !(state->flags & kCacheProtect)) {
275 cache_size_ += arcs.capacity() * sizeof(Arc);
276 if (cache_size_ > cache_limit_)
277 GC(s, false);
278 }
279 };
280
ReserveArcs(StateId s,size_t n)281 void ReserveArcs(StateId s, size_t n) {
282 S *state = ExtendState(s);
283 state->arcs.reserve(n);
284 }
285
DeleteArcs(StateId s,size_t n)286 void DeleteArcs(StateId s, size_t n) {
287 S *state = ExtendState(s);
288 const vector<Arc> &arcs = state->arcs;
289 for (size_t i = 0; i < n; ++i) {
290 size_t j = arcs.size() - i - 1;
291 if (arcs[j].ilabel == 0)
292 --state->niepsilons;
293 if (arcs[j].olabel == 0)
294 --state->noepsilons;
295 }
296
297 state->arcs.resize(arcs.size() - n);
298 SetProperties(DeleteArcsProperties(Properties()));
299 state->flags |= kCacheModified;
300 if (cache_gc_ && s != cache_first_state_id_ &&
301 !(state->flags & kCacheProtect)) {
302 cache_size_ -= n * sizeof(Arc);
303 }
304 }
305
DeleteArcs(StateId s)306 void DeleteArcs(StateId s) {
307 S *state = ExtendState(s);
308 size_t n = state->arcs.size();
309 state->niepsilons = 0;
310 state->noepsilons = 0;
311 state->arcs.clear();
312 SetProperties(DeleteArcsProperties(Properties()));
313 state->flags |= kCacheModified;
314 if (cache_gc_ && s != cache_first_state_id_ &&
315 !(state->flags & kCacheProtect)) {
316 cache_size_ -= n * sizeof(Arc);
317 }
318 }
319
DeleteStates(const vector<StateId> & dstates)320 void DeleteStates(const vector<StateId> &dstates) {
321 size_t old_num_states = NumStates();
322 vector<StateId> newid(old_num_states, 0);
323 for (size_t i = 0; i < dstates.size(); ++i)
324 newid[dstates[i]] = kNoStateId;
325 StateId nstates = 0;
326 for (StateId s = 0; s < old_num_states; ++s) {
327 if (newid[s] != kNoStateId) {
328 newid[s] = nstates;
329 ++nstates;
330 }
331 }
332 // just for states_.resize(), does unnecessary walk.
333 VectorFstBaseImpl<S>::DeleteStates(dstates);
334 SetProperties(DeleteStatesProperties(Properties()));
335 // Update list of cached states.
336 typename list<StateId>::iterator siter = cache_states_.begin();
337 while (siter != cache_states_.end()) {
338 if (newid[*siter] != kNoStateId) {
339 *siter = newid[*siter];
340 ++siter;
341 } else {
342 cache_states_.erase(siter++);
343 }
344 }
345 }
346
DeleteStates()347 void DeleteStates() {
348 cache_states_.clear();
349 allocator_->Free(cache_first_state_, cache_first_state_id_);
350 for (int s = 0; s < NumStates(); ++s) {
351 allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
352 SetState(s, 0);
353 }
354 nknown_states_ = 0;
355 min_unexpanded_state_id_ = 0;
356 cache_first_state_id_ = kNoStateId;
357 cache_first_state_ = 0;
358 cache_size_ = 0;
359 cache_start_ = false;
360 VectorFstBaseImpl<State>::DeleteStates();
361 SetProperties(DeleteAllStatesProperties(Properties(),
362 kExpanded | kMutable));
363 }
364
365 // Is the start state cached?
HasStart()366 bool HasStart() const {
367 if (!cache_start_ && Properties(kError))
368 cache_start_ = true;
369 return cache_start_;
370 }
371
372 // Is the final weight of state s cached?
HasFinal(StateId s)373 bool HasFinal(StateId s) const {
374 const S *state = CheckState(s);
375 if (state && state->flags & kCacheFinal) {
376 state->flags |= kCacheRecent;
377 return true;
378 } else {
379 return false;
380 }
381 }
382
383 // Are arcs of state s cached?
HasArcs(StateId s)384 bool HasArcs(StateId s) const {
385 const S *state = CheckState(s);
386 if (state && state->flags & kCacheArcs) {
387 state->flags |= kCacheRecent;
388 return true;
389 } else {
390 return false;
391 }
392 }
393
Final(StateId s)394 Weight Final(StateId s) const {
395 const S *state = GetState(s);
396 return state->final;
397 }
398
NumArcs(StateId s)399 size_t NumArcs(StateId s) const {
400 const S *state = GetState(s);
401 return state->arcs.size();
402 }
403
NumInputEpsilons(StateId s)404 size_t NumInputEpsilons(StateId s) const {
405 const S *state = GetState(s);
406 return state->niepsilons;
407 }
408
NumOutputEpsilons(StateId s)409 size_t NumOutputEpsilons(StateId s) const {
410 const S *state = GetState(s);
411 return state->noepsilons;
412 }
413
414 // Provides information needed for generic arc iterator.
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)415 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
416 const S *state = GetState(s);
417 data->base = 0;
418 data->narcs = state->arcs.size();
419 data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
420 data->ref_count = &(state->ref_count);
421 ++(*data->ref_count);
422 }
423
424 // Number of known states.
NumKnownStates()425 StateId NumKnownStates() const { return nknown_states_; }
426
427 // Update number of known states taking in account the existence of state s.
UpdateNumKnownStates(StateId s)428 void UpdateNumKnownStates(StateId s) {
429 if (s >= nknown_states_)
430 nknown_states_ = s + 1;
431 }
432
433 // Find the mininum never-expanded state Id
MinUnexpandedState()434 StateId MinUnexpandedState() const {
435 while (min_unexpanded_state_id_ < expanded_states_.size() &&
436 expanded_states_[min_unexpanded_state_id_])
437 ++min_unexpanded_state_id_;
438 return min_unexpanded_state_id_;
439 }
440
441 // Removes from cache_states_ and uncaches (not referenced-counted
442 // or protected) states that have not been accessed since the last
443 // GC until at most cache_fraction * cache_limit_ bytes are cached.
444 // If that fails to free enough, recurs uncaching recently visited
445 // states as well. If still unable to free enough memory, then
446 // widens cache_limit_ to fulfill condition.
447 void GC(StateId current, bool free_recent, float cache_fraction = 0.666);
448
449 // Setc/clears GC protection: if true, new states are protected
450 // from garbage collection.
GCProtect(bool on)451 void GCProtect(bool on) { protect_ = on; }
452
ExpandedState(StateId s)453 void ExpandedState(StateId s) {
454 if (s < min_unexpanded_state_id_)
455 return;
456 while (expanded_states_.size() <= s)
457 expanded_states_.push_back(false);
458 expanded_states_[s] = true;
459 }
460
GetAllocator()461 C *GetAllocator() const {
462 return allocator_;
463 }
464
465 // Caching on/off switch, limit and size accessors.
GetCacheGc()466 bool GetCacheGc() const { return cache_gc_; }
GetCacheLimit()467 size_t GetCacheLimit() const { return cache_limit_; }
GetCacheSize()468 size_t GetCacheSize() const { return cache_size_; }
469
470 private:
471 static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit
472
473 static const uint32 kCacheFinal = 0x0001; // Final weight has been cached
474 static const uint32 kCacheArcs = 0x0002; // Arcs have been cached
475 static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC
476 static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected
477
478 public:
479 static const uint32 kCacheModified = 0x0010; // Mark state as modified
480 static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
481 | kCacheProtect | kCacheModified;
482
483 private:
484 C *allocator_; // used to allocate new states
485 mutable bool cache_start_; // Is the start state cached?
486 StateId nknown_states_; // # of known states
487 vector<bool> expanded_states_; // states that have been expanded
488 mutable StateId min_unexpanded_state_id_; // minimum never-expanded state Id
489 StateId cache_first_state_id_; // First cached state id
490 S *cache_first_state_; // First cached state
491 list<StateId> cache_states_; // list of currently cached states
492 bool cache_gc_; // enable GC
493 size_t cache_size_; // # of bytes cached
494 size_t cache_limit_; // # of bytes allowed before GC
495 bool protect_; // Protect new states from GC
496
497 void operator=(const CacheBaseImpl<S, C> &impl); // disallow
498 };
499
500 // Gets a state from its ID; add it if necessary.
501 template <class S, class C>
ExtendState(typename S::Arc::StateId s)502 S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
503 // If 'protect_' true and a new state, protects from garbage collection.
504 if (s == cache_first_state_id_) {
505 return cache_first_state_; // Return 1st cached state
506 } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
507 cache_first_state_id_ = s; // Remember 1st cached state
508 cache_first_state_ = allocator_->Allocate(s);
509 if (protect_) cache_first_state_->flags |= kCacheProtect;
510 return cache_first_state_;
511 } else if (cache_first_state_id_ != kNoStateId &&
512 cache_first_state_->ref_count == 0 &&
513 !(cache_first_state_->flags & kCacheProtect)) {
514 // With Default allocator, the Free and Allocate will reuse the same S*.
515 allocator_->Free(cache_first_state_, cache_first_state_id_);
516 cache_first_state_id_ = s;
517 cache_first_state_ = allocator_->Allocate(s);
518 if (protect_) cache_first_state_->flags |= kCacheProtect;
519 return cache_first_state_; // Return 1st cached state
520 } else {
521 while (NumStates() <= s) // Add state to main cache
522 AddState(0);
523 S *state = VectorFstBaseImpl<S>::GetState(s);
524 if (!state) {
525 state = allocator_->Allocate(s);
526 if (protect_) state->flags |= kCacheProtect;
527 SetState(s, state);
528 if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state
529 while (NumStates() <= cache_first_state_id_)
530 AddState(0);
531 SetState(cache_first_state_id_, cache_first_state_);
532 if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
533 cache_states_.push_back(cache_first_state_id_);
534 cache_size_ += sizeof(S) +
535 cache_first_state_->arcs.capacity() * sizeof(Arc);
536 }
537 cache_limit_ = kMinCacheLimit;
538 cache_first_state_id_ = kNoStateId;
539 cache_first_state_ = 0;
540 }
541 if (cache_gc_ && !protect_) {
542 cache_states_.push_back(s);
543 cache_size_ += sizeof(S);
544 if (cache_size_ > cache_limit_)
545 GC(s, false);
546 }
547 }
548 return state;
549 }
550 }
551
552 // Removes from cache_states_ and uncaches (not referenced-counted or
553 // protected) states that have not been accessed since the last GC
554 // until at most cache_fraction * cache_limit_ bytes are cached. If
555 // that fails to free enough, recurs uncaching recently visited states
556 // as well. If still unable to free enough memory, then widens cache_limit_
557 // to fulfill condition.
558 template <class S, class C>
GC(typename S::Arc::StateId current,bool free_recent,float cache_fraction)559 void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
560 bool free_recent, float cache_fraction) {
561 if (!cache_gc_)
562 return;
563 VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
564 << "), free recently cached = " << free_recent
565 << ", cache size = " << cache_size_
566 << ", cache frac = " << cache_fraction
567 << ", cache limit = " << cache_limit_ << "\n";
568 typename list<StateId>::iterator siter = cache_states_.begin();
569
570 size_t cache_target = cache_fraction * cache_limit_;
571 while (siter != cache_states_.end()) {
572 StateId s = *siter;
573 S* state = VectorFstBaseImpl<S>::GetState(s);
574 if (cache_size_ > cache_target && state->ref_count == 0 &&
575 (free_recent || !(state->flags & kCacheRecent)) && s != current) {
576 cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
577 allocator_->Free(state, s);
578 SetState(s, 0);
579 cache_states_.erase(siter++);
580 } else {
581 state->flags &= ~kCacheRecent;
582 ++siter;
583 }
584 }
585 if (!free_recent && cache_size_ > cache_target) { // recurses on recent
586 GC(current, true);
587 } else if (cache_target > 0) { // widens cache limit
588 while (cache_size_ > cache_target) {
589 cache_limit_ *= 2;
590 cache_target *= 2;
591 }
592 } else if (cache_size_ > 0) {
593 FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
594 }
595 VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
596 << "), free recently cached = " << free_recent
597 << ", cache size = " << cache_size_
598 << ", cache frac = " << cache_fraction
599 << ", cache limit = " << cache_limit_ << "\n";
600 }
601
602 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
603 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
604 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
605 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
606 template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;
607
608 // Arcs implemented by an STL vector per state. Similar to VectorState
609 // but adds flags and ref count to keep track of what has been cached.
610 template <class A>
611 struct CacheState {
612 typedef A Arc;
613 typedef typename A::Weight Weight;
614 typedef typename A::StateId StateId;
615
CacheStateCacheState616 CacheState() : final(Weight::Zero()), flags(0), ref_count(0) {}
617
ResetCacheState618 void Reset() {
619 flags = 0;
620 ref_count = 0;
621 arcs.resize(0);
622 }
623
624 Weight final; // Final weight
625 vector<A> arcs; // Arcs represenation
626 size_t niepsilons; // # of input epsilons
627 size_t noepsilons; // # of output epsilons
628 mutable uint32 flags;
629 mutable int ref_count;
630 };
631
632 // A CacheBaseImpl with a commonly used CacheState.
633 template <class A>
634 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
635 public:
636 typedef CacheState<A> State;
637
CacheImpl()638 CacheImpl() {}
639
CacheImpl(const CacheOptions & opts)640 explicit CacheImpl(const CacheOptions &opts)
641 : CacheBaseImpl< CacheState<A> >(opts) {}
642
643 CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
644 : CacheBaseImpl<State>(impl, preserve_cache) {}
645
646 private:
647 void operator=(const CacheImpl<State> &impl); // disallow
648 };
649
650
651 // Use this to make a state iterator for a CacheBaseImpl-derived Fst,
652 // which must have type 'State' defined. Note this iterator only
653 // returns those states reachable from the initial state, so consider
654 // implementing a class-specific one.
655 template <class F>
656 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
657 public:
658 typedef typename F::Arc Arc;
659 typedef typename Arc::StateId StateId;
660 typedef typename F::State State;
661 typedef CacheBaseImpl<State> Impl;
662
CacheStateIterator(const F & fst,Impl * impl)663 CacheStateIterator(const F &fst, Impl *impl)
664 : fst_(fst), impl_(impl), s_(0) {
665 fst_.Start(); // force start state
666 }
667
Done()668 bool Done() const {
669 if (s_ < impl_->NumKnownStates())
670 return false;
671 if (s_ < impl_->NumKnownStates())
672 return false;
673 for (StateId u = impl_->MinUnexpandedState();
674 u < impl_->NumKnownStates();
675 u = impl_->MinUnexpandedState()) {
676 // force state expansion
677 ArcIterator<F> aiter(fst_, u);
678 aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
679 for (; !aiter.Done(); aiter.Next())
680 impl_->UpdateNumKnownStates(aiter.Value().nextstate);
681 impl_->ExpandedState(u);
682 if (s_ < impl_->NumKnownStates())
683 return false;
684 }
685 return true;
686 }
687
Value()688 StateId Value() const { return s_; }
689
Next()690 void Next() { ++s_; }
691
Reset()692 void Reset() { s_ = 0; }
693
694 private:
695 // This allows base class virtual access to non-virtual derived-
696 // class members of the same name. It makes the derived class more
697 // efficient to use but unsafe to further derive.
Done_()698 virtual bool Done_() const { return Done(); }
Value_()699 virtual StateId Value_() const { return Value(); }
Next_()700 virtual void Next_() { Next(); }
Reset_()701 virtual void Reset_() { Reset(); }
702
703 const F &fst_;
704 Impl *impl_;
705 StateId s_;
706 };
707
708
709 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst,
710 // which must have types 'Arc' and 'State' defined.
711 template <class F,
712 class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
713 class CacheArcIterator {
714 public:
715 typedef typename F::Arc Arc;
716 typedef typename F::State State;
717 typedef typename Arc::StateId StateId;
718 typedef CacheBaseImpl<State, C> Impl;
719
CacheArcIterator(Impl * impl,StateId s)720 CacheArcIterator(Impl *impl, StateId s) : i_(0) {
721 state_ = impl->ExtendState(s);
722 ++state_->ref_count;
723 }
724
~CacheArcIterator()725 ~CacheArcIterator() { --state_->ref_count; }
726
Done()727 bool Done() const { return i_ >= state_->arcs.size(); }
728
Value()729 const Arc& Value() const { return state_->arcs[i_]; }
730
Next()731 void Next() { ++i_; }
732
Position()733 size_t Position() const { return i_; }
734
Reset()735 void Reset() { i_ = 0; }
736
Seek(size_t a)737 void Seek(size_t a) { i_ = a; }
738
Flags()739 uint32 Flags() const {
740 return kArcValueFlags;
741 }
742
SetFlags(uint32 flags,uint32 mask)743 void SetFlags(uint32 flags, uint32 mask) {}
744
745 private:
746 const State *state_;
747 size_t i_;
748
749 DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
750 };
751
752 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
753 // which must have types 'Arc' and 'State' defined.
754 template <class F,
755 class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
756 class CacheMutableArcIterator
757 : public MutableArcIteratorBase<typename F::Arc> {
758 public:
759 typedef typename F::State State;
760 typedef typename F::Arc Arc;
761 typedef typename Arc::StateId StateId;
762 typedef typename Arc::Weight Weight;
763 typedef CacheBaseImpl<State, C> Impl;
764
765 // You will need to call MutateCheck() in the constructor.
CacheMutableArcIterator(Impl * impl,StateId s)766 CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
767 state_ = impl_->ExtendState(s_);
768 ++state_->ref_count;
769 };
770
~CacheMutableArcIterator()771 ~CacheMutableArcIterator() {
772 --state_->ref_count;
773 }
774
Done()775 bool Done() const { return i_ >= state_->arcs.size(); }
776
Value()777 const Arc& Value() const { return state_->arcs[i_]; }
778
Next()779 void Next() { ++i_; }
780
Position()781 size_t Position() const { return i_; }
782
Reset()783 void Reset() { i_ = 0; }
784
Seek(size_t a)785 void Seek(size_t a) { i_ = a; }
786
SetValue(const Arc & arc)787 void SetValue(const Arc& arc) {
788 state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
789 uint64 properties = impl_->Properties();
790 Arc& oarc = state_->arcs[i_];
791 if (oarc.ilabel != oarc.olabel)
792 properties &= ~kNotAcceptor;
793 if (oarc.ilabel == 0) {
794 --state_->niepsilons;
795 properties &= ~kIEpsilons;
796 if (oarc.olabel == 0)
797 properties &= ~kEpsilons;
798 }
799 if (oarc.olabel == 0) {
800 --state_->noepsilons;
801 properties &= ~kOEpsilons;
802 }
803 if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
804 properties &= ~kWeighted;
805 oarc = arc;
806 if (arc.ilabel != arc.olabel) {
807 properties |= kNotAcceptor;
808 properties &= ~kAcceptor;
809 }
810 if (arc.ilabel == 0) {
811 ++state_->niepsilons;
812 properties |= kIEpsilons;
813 properties &= ~kNoIEpsilons;
814 if (arc.olabel == 0) {
815 properties |= kEpsilons;
816 properties &= ~kNoEpsilons;
817 }
818 }
819 if (arc.olabel == 0) {
820 ++state_->noepsilons;
821 properties |= kOEpsilons;
822 properties &= ~kNoOEpsilons;
823 }
824 if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
825 properties |= kWeighted;
826 properties &= ~kUnweighted;
827 }
828 properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
829 kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
830 kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
831 impl_->SetProperties(properties);
832 }
833
Flags()834 uint32 Flags() const {
835 return kArcValueFlags;
836 }
837
SetFlags(uint32 f,uint32 m)838 void SetFlags(uint32 f, uint32 m) {}
839
840 private:
Done_()841 virtual bool Done_() const { return Done(); }
Value_()842 virtual const Arc& Value_() const { return Value(); }
Next_()843 virtual void Next_() { Next(); }
Position_()844 virtual size_t Position_() const { return Position(); }
Reset_()845 virtual void Reset_() { Reset(); }
Seek_(size_t a)846 virtual void Seek_(size_t a) { Seek(a); }
SetValue_(const Arc & a)847 virtual void SetValue_(const Arc &a) { SetValue(a); }
Flags_()848 uint32 Flags_() const { return Flags(); }
SetFlags_(uint32 f,uint32 m)849 void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
850
851 size_t i_;
852 StateId s_;
853 Impl *impl_;
854 State *state_;
855
856 DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
857 };
858
859 } // namespace fst
860
861 #endif // FST_LIB_CACHE_H__
862