1 // accumulator.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley@google.com (Michael Riley) 17 // 18 // \file 19 // Classes to accumulate arc weights. Useful for weight lookahead. 20 21 #ifndef FST_LIB_ACCUMULATOR_H__ 22 #define FST_LIB_ACCUMULATOR_H__ 23 24 #include <algorithm> 25 #include <functional> 26 #include <tr1/unordered_map> 27 using std::tr1::unordered_map; 28 using std::tr1::unordered_multimap; 29 #include <vector> 30 using std::vector; 31 32 #include <fst/arcfilter.h> 33 #include <fst/arcsort.h> 34 #include <fst/dfs-visit.h> 35 #include <fst/expanded-fst.h> 36 #include <fst/replace.h> 37 38 namespace fst { 39 40 // This class accumulates arc weights using the semiring Plus(). 41 template <class A> 42 class DefaultAccumulator { 43 public: 44 typedef A Arc; 45 typedef typename A::StateId StateId; 46 typedef typename A::Weight Weight; 47 DefaultAccumulator()48 DefaultAccumulator() {} 49 DefaultAccumulator(const DefaultAccumulator<A> & acc)50 DefaultAccumulator(const DefaultAccumulator<A> &acc) {} 51 52 void Init(const Fst<A>& fst, bool copy = false) {} 53 SetState(StateId)54 void SetState(StateId) {} 55 Sum(Weight w,Weight v)56 Weight Sum(Weight w, Weight v) { 57 return Plus(w, v); 58 } 59 60 template <class ArcIterator> Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)61 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 62 ssize_t end) { 63 Weight sum = w; 64 aiter->Seek(begin); 65 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 66 sum = Plus(sum, aiter->Value().weight); 67 return sum; 68 } 69 Error()70 bool Error() const { return false; } 71 72 private: 73 void operator=(const DefaultAccumulator<A> &); // Disallow 74 }; 75 76 77 // This class accumulates arc weights using the log semiring Plus() 78 // assuming an arc weight has a WeightConvert specialization to 79 // and from log64 weights. 80 template <class A> 81 class LogAccumulator { 82 public: 83 typedef A Arc; 84 typedef typename A::StateId StateId; 85 typedef typename A::Weight Weight; 86 LogAccumulator()87 LogAccumulator() {} 88 LogAccumulator(const LogAccumulator<A> & acc)89 LogAccumulator(const LogAccumulator<A> &acc) {} 90 91 void Init(const Fst<A>& fst, bool copy = false) {} 92 SetState(StateId)93 void SetState(StateId) {} 94 Sum(Weight w,Weight v)95 Weight Sum(Weight w, Weight v) { 96 return LogPlus(w, v); 97 } 98 99 template <class ArcIterator> Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)100 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 101 ssize_t end) { 102 Weight sum = w; 103 aiter->Seek(begin); 104 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 105 sum = LogPlus(sum, aiter->Value().weight); 106 return sum; 107 } 108 Error()109 bool Error() const { return false; } 110 111 private: LogPosExp(double x)112 double LogPosExp(double x) { return log(1.0F + exp(-x)); } 113 LogPlus(Weight w,Weight v)114 Weight LogPlus(Weight w, Weight v) { 115 double f1 = to_log_weight_(w).Value(); 116 double f2 = to_log_weight_(v).Value(); 117 if (f1 > f2) 118 return to_weight_(f2 - LogPosExp(f1 - f2)); 119 else 120 return to_weight_(f1 - LogPosExp(f2 - f1)); 121 } 122 123 WeightConvert<Weight, Log64Weight> to_log_weight_; 124 WeightConvert<Log64Weight, Weight> to_weight_; 125 126 void operator=(const LogAccumulator<A> &); // Disallow 127 }; 128 129 130 // Stores shareable data for fast log accumulator copies. 131 class FastLogAccumulatorData { 132 public: FastLogAccumulatorData()133 FastLogAccumulatorData() {} 134 Weights()135 vector<double> *Weights() { return &weights_; } WeightPositions()136 vector<ssize_t> *WeightPositions() { return &weight_positions_; } WeightEnd()137 double *WeightEnd() { return &(weights_[weights_.size() - 1]); }; RefCount()138 int RefCount() const { return ref_count_.count(); } IncrRefCount()139 int IncrRefCount() { return ref_count_.Incr(); } DecrRefCount()140 int DecrRefCount() { return ref_count_.Decr(); } 141 142 private: 143 // Cummulative weight per state for all states s.t. # of arcs > 144 // arc_limit_ with arcs in order. Special first element per state 145 // being Log64Weight::Zero(); 146 vector<double> weights_; 147 // Maps from state to corresponding beginning weight position in 148 // weights_. Position -1 means no pre-computed weights for that 149 // state. 150 vector<ssize_t> weight_positions_; 151 RefCounter ref_count_; // Reference count. 152 153 DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData); 154 }; 155 156 157 // This class accumulates arc weights using the log semiring Plus() 158 // assuming an arc weight has a WeightConvert specialization to and 159 // from log64 weights. The member function Init(fst) has to be called 160 // to setup pre-computed weight information. 161 template <class A> 162 class FastLogAccumulator { 163 public: 164 typedef A Arc; 165 typedef typename A::StateId StateId; 166 typedef typename A::Weight Weight; 167 168 explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) arc_limit_(arc_limit)169 : arc_limit_(arc_limit), 170 arc_period_(arc_period), 171 data_(new FastLogAccumulatorData()), 172 error_(false) {} 173 FastLogAccumulator(const FastLogAccumulator<A> & acc)174 FastLogAccumulator(const FastLogAccumulator<A> &acc) 175 : arc_limit_(acc.arc_limit_), 176 arc_period_(acc.arc_period_), 177 data_(acc.data_), 178 error_(acc.error_) { 179 data_->IncrRefCount(); 180 } 181 ~FastLogAccumulator()182 ~FastLogAccumulator() { 183 if (!data_->DecrRefCount()) 184 delete data_; 185 } 186 SetState(StateId s)187 void SetState(StateId s) { 188 vector<double> &weights = *data_->Weights(); 189 vector<ssize_t> &weight_positions = *data_->WeightPositions(); 190 191 if (weight_positions.size() <= s) { 192 FSTERROR() << "FastLogAccumulator::SetState: invalid state id."; 193 error_ = true; 194 return; 195 } 196 197 ssize_t pos = weight_positions[s]; 198 if (pos >= 0) 199 state_weights_ = &(weights[pos]); 200 else 201 state_weights_ = 0; 202 } 203 Sum(Weight w,Weight v)204 Weight Sum(Weight w, Weight v) { 205 return LogPlus(w, v); 206 } 207 208 template <class ArcIterator> Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)209 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 210 ssize_t end) { 211 if (error_) return Weight::NoWeight(); 212 Weight sum = w; 213 // Finds begin and end of pre-stored weights 214 ssize_t index_begin = -1, index_end = -1; 215 ssize_t stored_begin = end, stored_end = end; 216 if (state_weights_ != 0) { 217 index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0; 218 index_end = end / arc_period_; 219 stored_begin = index_begin * arc_period_; 220 stored_end = index_end * arc_period_; 221 } 222 // Computes sum before pre-stored weights 223 if (begin < stored_begin) { 224 ssize_t pos_end = min(stored_begin, end); 225 aiter->Seek(begin); 226 for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos) 227 sum = LogPlus(sum, aiter->Value().weight); 228 } 229 // Computes sum between pre-stored weights 230 if (stored_begin < stored_end) { 231 sum = LogPlus(sum, LogMinus(state_weights_[index_end], 232 state_weights_[index_begin])); 233 } 234 // Computes sum after pre-stored weights 235 if (stored_end < end) { 236 ssize_t pos_start = max(stored_begin, stored_end); 237 aiter->Seek(pos_start); 238 for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos) 239 sum = LogPlus(sum, aiter->Value().weight); 240 } 241 return sum; 242 } 243 244 template <class F> 245 void Init(const F &fst, bool copy = false) { 246 if (copy) 247 return; 248 vector<double> &weights = *data_->Weights(); 249 vector<ssize_t> &weight_positions = *data_->WeightPositions(); 250 if (!weights.empty() || arc_limit_ < arc_period_) { 251 FSTERROR() << "FastLogAccumulator: initialization error."; 252 error_ = true; 253 return; 254 } 255 weight_positions.reserve(CountStates(fst)); 256 257 ssize_t weight_position = 0; 258 for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { 259 StateId s = siter.Value(); 260 if (fst.NumArcs(s) >= arc_limit_) { 261 double sum = FloatLimits<double>::PosInfinity(); 262 weight_positions.push_back(weight_position); 263 weights.push_back(sum); 264 ++weight_position; 265 ssize_t narcs = 0; 266 for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { 267 const A &arc = aiter.Value(); 268 sum = LogPlus(sum, arc.weight); 269 // Stores cumulative weight distribution per arc_period_. 270 if (++narcs % arc_period_ == 0) { 271 weights.push_back(sum); 272 ++weight_position; 273 } 274 } 275 } else { 276 weight_positions.push_back(-1); 277 } 278 } 279 } 280 Error()281 bool Error() const { return error_; } 282 283 private: LogPosExp(double x)284 double LogPosExp(double x) { 285 return x == FloatLimits<double>::PosInfinity() ? 286 0.0 : log(1.0F + exp(-x)); 287 } 288 LogMinusExp(double x)289 double LogMinusExp(double x) { 290 return x == FloatLimits<double>::PosInfinity() ? 291 0.0 : log(1.0F - exp(-x)); 292 } 293 LogPlus(Weight w,Weight v)294 Weight LogPlus(Weight w, Weight v) { 295 double f1 = to_log_weight_(w).Value(); 296 double f2 = to_log_weight_(v).Value(); 297 if (f1 > f2) 298 return to_weight_(f2 - LogPosExp(f1 - f2)); 299 else 300 return to_weight_(f1 - LogPosExp(f2 - f1)); 301 } 302 LogPlus(double f1,Weight v)303 double LogPlus(double f1, Weight v) { 304 double f2 = to_log_weight_(v).Value(); 305 if (f1 == FloatLimits<double>::PosInfinity()) 306 return f2; 307 else if (f1 > f2) 308 return f2 - LogPosExp(f1 - f2); 309 else 310 return f1 - LogPosExp(f2 - f1); 311 } 312 LogMinus(double f1,double f2)313 Weight LogMinus(double f1, double f2) { 314 if (f1 >= f2) { 315 FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 316 << " and f2 = " << f2; 317 error_ = true; 318 return Weight::NoWeight(); 319 } 320 if (f2 == FloatLimits<double>::PosInfinity()) 321 return to_weight_(f1); 322 else 323 return to_weight_(f1 - LogMinusExp(f2 - f1)); 324 } 325 326 WeightConvert<Weight, Log64Weight> to_log_weight_; 327 WeightConvert<Log64Weight, Weight> to_weight_; 328 329 ssize_t arc_limit_; // Minimum # of arcs to pre-compute state 330 ssize_t arc_period_; // Save cumulative weights per 'arc_period_'. 331 bool init_; // Cumulative weights initialized? 332 FastLogAccumulatorData *data_; 333 double *state_weights_; 334 bool error_; 335 336 void operator=(const FastLogAccumulator<A> &); // Disallow 337 }; 338 339 340 // Stores shareable data for cache log accumulator copies. 341 // All copies share the same cache. 342 template <class A> 343 class CacheLogAccumulatorData { 344 public: 345 typedef A Arc; 346 typedef typename A::StateId StateId; 347 typedef typename A::Weight Weight; 348 CacheLogAccumulatorData(bool gc,size_t gc_limit)349 CacheLogAccumulatorData(bool gc, size_t gc_limit) 350 : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} 351 ~CacheLogAccumulatorData()352 ~CacheLogAccumulatorData() { 353 for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); 354 it != cache_.end(); 355 ++it) 356 delete it->second.weights; 357 } 358 CacheDisabled()359 bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } 360 GetWeights(StateId s)361 vector<double> *GetWeights(StateId s) { 362 typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s); 363 if (it != cache_.end()) { 364 it->second.recent = true; 365 return it->second.weights; 366 } else { 367 return 0; 368 } 369 } 370 AddWeights(StateId s,vector<double> * weights)371 void AddWeights(StateId s, vector<double> *weights) { 372 if (cache_gc_ && cache_size_ >= cache_limit_) 373 GC(false); 374 cache_.insert(make_pair(s, CacheState(weights, true))); 375 if (cache_gc_) 376 cache_size_ += weights->capacity() * sizeof(double); 377 } 378 RefCount()379 int RefCount() const { return ref_count_.count(); } IncrRefCount()380 int IncrRefCount() { return ref_count_.Incr(); } DecrRefCount()381 int DecrRefCount() { return ref_count_.Decr(); } 382 383 private: 384 // Cached information for a given state. 385 struct CacheState { 386 vector<double>* weights; // Accumulated weights for this state. 387 bool recent; // Has this state been accessed since last GC? 388 CacheStateCacheState389 CacheState(vector<double> *w, bool r) : weights(w), recent(r) {} 390 }; 391 392 // Garbage collect: Delete from cache states that have not been 393 // accessed since the last GC ('free_recent = false') until 394 // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough 395 // memory, start deleting recently accessed states. GC(bool free_recent)396 void GC(bool free_recent) { 397 size_t cache_target = (2 * cache_limit_)/3 + 1; 398 typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); 399 while (it != cache_.end() && cache_size_ > cache_target) { 400 CacheState &cs = it->second; 401 if (free_recent || !cs.recent) { 402 cache_size_ -= cs.weights->capacity() * sizeof(double); 403 delete cs.weights; 404 cache_.erase(it++); 405 } else { 406 cs.recent = false; 407 ++it; 408 } 409 } 410 if (!free_recent && cache_size_ > cache_target) 411 GC(true); 412 } 413 414 unordered_map<StateId, CacheState> cache_; // Cache 415 bool cache_gc_; // Enable garbage collection 416 size_t cache_limit_; // # of bytes cached 417 size_t cache_size_; // # of bytes allowed before GC 418 RefCounter ref_count_; 419 420 DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData); 421 }; 422 423 // This class accumulates arc weights using the log semiring Plus() 424 // has a WeightConvert specialization to and from log64 weights. It 425 // is similar to the FastLogAccumator. However here, the accumulated 426 // weights are pre-computed and stored only for the states that are 427 // visited. The member function Init(fst) has to be called to setup 428 // this accumulator. 429 template <class A> 430 class CacheLogAccumulator { 431 public: 432 typedef A Arc; 433 typedef typename A::StateId StateId; 434 typedef typename A::Weight Weight; 435 436 explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, 437 size_t gc_limit = 10 * 1024 * 1024) arc_limit_(arc_limit)438 : arc_limit_(arc_limit), fst_(0), data_( 439 new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId), 440 error_(false) {} 441 CacheLogAccumulator(const CacheLogAccumulator<A> & acc)442 CacheLogAccumulator(const CacheLogAccumulator<A> &acc) 443 : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0), 444 data_(acc.data_), s_(kNoStateId), error_(acc.error_) { 445 data_->IncrRefCount(); 446 } 447 ~CacheLogAccumulator()448 ~CacheLogAccumulator() { 449 if (fst_) 450 delete fst_; 451 if (!data_->DecrRefCount()) 452 delete data_; 453 } 454 455 // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state. 456 void Init(const Fst<A> &fst, bool copy = false) { 457 if (copy) { 458 delete fst_; 459 } else if (fst_) { 460 FSTERROR() << "CacheLogAccumulator: initialization error."; 461 error_ = true; 462 return; 463 } 464 fst_ = fst.Copy(); 465 } 466 467 void SetState(StateId s, int depth = 0) { 468 if (s == s_) 469 return; 470 s_ = s; 471 472 if (data_->CacheDisabled() || error_) { 473 weights_ = 0; 474 return; 475 } 476 477 if (!fst_) { 478 FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized."; 479 error_ = true; 480 weights_ = 0; 481 return; 482 } 483 484 weights_ = data_->GetWeights(s); 485 if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { 486 weights_ = new vector<double>; 487 weights_->reserve(fst_->NumArcs(s) + 1); 488 weights_->push_back(FloatLimits<double>::PosInfinity()); 489 data_->AddWeights(s, weights_); 490 } 491 } 492 Sum(Weight w,Weight v)493 Weight Sum(Weight w, Weight v) { 494 return LogPlus(w, v); 495 } 496 497 template <class Iterator> Sum(Weight w,Iterator * aiter,ssize_t begin,ssize_t end)498 Weight Sum(Weight w, Iterator *aiter, ssize_t begin, 499 ssize_t end) { 500 if (weights_ == 0) { 501 Weight sum = w; 502 aiter->Seek(begin); 503 for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 504 sum = LogPlus(sum, aiter->Value().weight); 505 return sum; 506 } else { 507 if (weights_->size() <= end) 508 for (aiter->Seek(weights_->size() - 1); 509 weights_->size() <= end; 510 aiter->Next()) 511 weights_->push_back(LogPlus(weights_->back(), 512 aiter->Value().weight)); 513 return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin])); 514 } 515 } 516 517 template <class Iterator> LowerBound(double w,Iterator * aiter)518 size_t LowerBound(double w, Iterator *aiter) { 519 if (weights_ != 0) { 520 return lower_bound(weights_->begin() + 1, 521 weights_->end(), 522 w, 523 std::greater<double>()) 524 - weights_->begin() - 1; 525 } else { 526 size_t n = 0; 527 double x = FloatLimits<double>::PosInfinity(); 528 for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { 529 x = LogPlus(x, aiter->Value().weight); 530 if (x < w) break; 531 } 532 return n; 533 } 534 } 535 Error()536 bool Error() const { return error_; } 537 538 private: LogPosExp(double x)539 double LogPosExp(double x) { 540 return x == FloatLimits<double>::PosInfinity() ? 541 0.0 : log(1.0F + exp(-x)); 542 } 543 LogMinusExp(double x)544 double LogMinusExp(double x) { 545 return x == FloatLimits<double>::PosInfinity() ? 546 0.0 : log(1.0F - exp(-x)); 547 } 548 LogPlus(Weight w,Weight v)549 Weight LogPlus(Weight w, Weight v) { 550 double f1 = to_log_weight_(w).Value(); 551 double f2 = to_log_weight_(v).Value(); 552 if (f1 > f2) 553 return to_weight_(f2 - LogPosExp(f1 - f2)); 554 else 555 return to_weight_(f1 - LogPosExp(f2 - f1)); 556 } 557 LogPlus(double f1,Weight v)558 double LogPlus(double f1, Weight v) { 559 double f2 = to_log_weight_(v).Value(); 560 if (f1 == FloatLimits<double>::PosInfinity()) 561 return f2; 562 else if (f1 > f2) 563 return f2 - LogPosExp(f1 - f2); 564 else 565 return f1 - LogPosExp(f2 - f1); 566 } 567 LogMinus(double f1,double f2)568 Weight LogMinus(double f1, double f2) { 569 if (f1 >= f2) { 570 FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 571 << " and f2 = " << f2; 572 error_ = true; 573 return Weight::NoWeight(); 574 } 575 if (f2 == FloatLimits<double>::PosInfinity()) 576 return to_weight_(f1); 577 else 578 return to_weight_(f1 - LogMinusExp(f2 - f1)); 579 } 580 581 WeightConvert<Weight, Log64Weight> to_log_weight_; 582 WeightConvert<Log64Weight, Weight> to_weight_; 583 584 ssize_t arc_limit_; // Minimum # of arcs to cache a state 585 vector<double> *weights_; // Accumulated weights for cur. state 586 const Fst<A>* fst_; // Input fst 587 CacheLogAccumulatorData<A> *data_; // Cache data 588 StateId s_; // Current state 589 bool error_; 590 591 void operator=(const CacheLogAccumulator<A> &); // Disallow 592 }; 593 594 595 // Stores shareable data for replace accumulator copies. 596 template <class Accumulator, class T> 597 class ReplaceAccumulatorData { 598 public: 599 typedef typename Accumulator::Arc Arc; 600 typedef typename Arc::StateId StateId; 601 typedef typename Arc::Label Label; 602 typedef T StateTable; 603 typedef typename T::StateTuple StateTuple; 604 ReplaceAccumulatorData()605 ReplaceAccumulatorData() : state_table_(0) {} 606 ReplaceAccumulatorData(const vector<Accumulator * > & accumulators)607 ReplaceAccumulatorData(const vector<Accumulator*> &accumulators) 608 : state_table_(0), accumulators_(accumulators) {} 609 ~ReplaceAccumulatorData()610 ~ReplaceAccumulatorData() { 611 for (size_t i = 0; i < fst_array_.size(); ++i) 612 delete fst_array_[i]; 613 for (size_t i = 0; i < accumulators_.size(); ++i) 614 delete accumulators_[i]; 615 } 616 Init(const vector<pair<Label,const Fst<Arc> * >> & fst_tuples,const StateTable * state_table)617 void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, 618 const StateTable *state_table) { 619 state_table_ = state_table; 620 accumulators_.resize(fst_tuples.size()); 621 for (size_t i = 0; i < accumulators_.size(); ++i) { 622 if (!accumulators_[i]) 623 accumulators_[i] = new Accumulator; 624 accumulators_[i]->Init(*(fst_tuples[i].second)); 625 fst_array_.push_back(fst_tuples[i].second->Copy()); 626 } 627 } 628 GetTuple(StateId s)629 const StateTuple &GetTuple(StateId s) const { 630 return state_table_->Tuple(s); 631 } 632 GetAccumulator(size_t i)633 Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; } 634 GetFst(size_t i)635 const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; } 636 RefCount()637 int RefCount() const { return ref_count_.count(); } IncrRefCount()638 int IncrRefCount() { return ref_count_.Incr(); } DecrRefCount()639 int DecrRefCount() { return ref_count_.Decr(); } 640 641 private: 642 const T * state_table_; 643 vector<Accumulator*> accumulators_; 644 vector<const Fst<Arc>*> fst_array_; 645 RefCounter ref_count_; 646 647 DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData); 648 }; 649 650 // This class accumulates weights in a ReplaceFst. The 'Init' method 651 // takes as input the argument used to build the ReplaceFst and the 652 // ReplaceFst state table. It uses accumulators of type 'Accumulator' 653 // in the underlying FSTs. 654 template <class Accumulator, 655 class T = DefaultReplaceStateTable<typename Accumulator::Arc> > 656 class ReplaceAccumulator { 657 public: 658 typedef typename Accumulator::Arc Arc; 659 typedef typename Arc::StateId StateId; 660 typedef typename Arc::Label Label; 661 typedef typename Arc::Weight Weight; 662 typedef T StateTable; 663 typedef typename T::StateTuple StateTuple; 664 ReplaceAccumulator()665 ReplaceAccumulator() 666 : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()), 667 error_(false) {} 668 ReplaceAccumulator(const vector<Accumulator * > & accumulators)669 ReplaceAccumulator(const vector<Accumulator*> &accumulators) 670 : init_(false), 671 data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)), 672 error_(false) {} 673 ReplaceAccumulator(const ReplaceAccumulator<Accumulator,T> & acc)674 ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc) 675 : init_(acc.init_), data_(acc.data_), error_(acc.error_) { 676 if (!init_) 677 FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator"; 678 data_->IncrRefCount(); 679 } 680 ~ReplaceAccumulator()681 ~ReplaceAccumulator() { 682 if (!data_->DecrRefCount()) 683 delete data_; 684 } 685 686 // Does not take ownership of the state table, the state table 687 // is own by the ReplaceFst Init(const vector<pair<Label,const Fst<Arc> * >> & fst_tuples,const StateTable * state_table)688 void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, 689 const StateTable *state_table) { 690 init_ = true; 691 data_->Init(fst_tuples, state_table); 692 } 693 SetState(StateId s)694 void SetState(StateId s) { 695 if (!init_) { 696 FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized."; 697 error_ = true; 698 return; 699 } 700 StateTuple tuple = data_->GetTuple(s); 701 fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based 702 data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); 703 if ((tuple.prefix_id != 0) && 704 (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { 705 offset_ = 1; 706 offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); 707 } else { 708 offset_ = 0; 709 offset_weight_ = Weight::Zero(); 710 } 711 } 712 Sum(Weight w,Weight v)713 Weight Sum(Weight w, Weight v) { 714 if (error_) return Weight::NoWeight(); 715 return data_->GetAccumulator(fst_id_)->Sum(w, v); 716 } 717 718 template <class ArcIterator> Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)719 Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 720 ssize_t end) { 721 if (error_) return Weight::NoWeight(); 722 Weight sum = begin == end ? Weight::Zero() 723 : data_->GetAccumulator(fst_id_)->Sum( 724 w, aiter, begin ? begin - offset_ : 0, end - offset_); 725 if (begin == 0 && end != 0 && offset_ > 0) 726 sum = Sum(offset_weight_, sum); 727 return sum; 728 } 729 Error()730 bool Error() const { return error_; } 731 732 private: 733 bool init_; 734 ReplaceAccumulatorData<Accumulator, T> *data_; 735 Label fst_id_; 736 size_t offset_; 737 Weight offset_weight_; 738 bool error_; 739 740 void operator=(const ReplaceAccumulator<Accumulator, T> &); // Disallow 741 }; 742 743 } // namespace fst 744 745 #endif // FST_LIB_ACCUMULATOR_H__ 746