1 // label_reachable.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley@google.com (Michael Riley) 17 // 18 // \file 19 // Class to determine if a non-epsilon label can be read as the 20 // first non-epsilon symbol along some path from a given state. 21 22 23 #ifndef FST_LIB_LABEL_REACHABLE_H__ 24 #define FST_LIB_LABEL_REACHABLE_H__ 25 26 #include <tr1/unordered_map> 27 using std::tr1::unordered_map; 28 using std::tr1::unordered_multimap; 29 #include <vector> 30 using std::vector; 31 32 #include <fst/accumulator.h> 33 #include <fst/arcsort.h> 34 #include <fst/interval-set.h> 35 #include <fst/state-reachable.h> 36 #include <fst/vector-fst.h> 37 38 39 namespace fst { 40 41 // Stores shareable data for label reachable class copies. 42 template <typename L> 43 class LabelReachableData { 44 public: 45 typedef L Label; 46 typedef typename IntervalSet<L>::Interval Interval; 47 48 explicit LabelReachableData(bool reach_input, bool keep_relabel_data = true) reach_input_(reach_input)49 : reach_input_(reach_input), 50 keep_relabel_data_(keep_relabel_data), 51 have_relabel_data_(true), 52 final_label_(kNoLabel) {} 53 ~LabelReachableData()54 ~LabelReachableData() {} 55 ReachInput()56 bool ReachInput() const { return reach_input_; } 57 IntervalSets()58 vector< IntervalSet<L> > *IntervalSets() { return &isets_; } 59 Label2Index()60 unordered_map<L, L> *Label2Index() { 61 if (!have_relabel_data_) 62 FSTERROR() << "LabelReachableData: no relabeling data"; 63 return &label2index_; 64 } 65 FinalLabel()66 Label FinalLabel() { 67 if (final_label_ == kNoLabel) 68 final_label_ = label2index_[kNoLabel]; 69 return final_label_; 70 } 71 Read(istream & istrm)72 static LabelReachableData<L> *Read(istream &istrm) { 73 LabelReachableData<L> *data = new LabelReachableData<L>(); 74 75 ReadType(istrm, &data->reach_input_); 76 ReadType(istrm, &data->keep_relabel_data_); 77 data->have_relabel_data_ = data->keep_relabel_data_; 78 if (data->keep_relabel_data_) 79 ReadType(istrm, &data->label2index_); 80 ReadType(istrm, &data->final_label_); 81 ReadType(istrm, &data->isets_); 82 return data; 83 } 84 Write(ostream & ostrm)85 bool Write(ostream &ostrm) { 86 WriteType(ostrm, reach_input_); 87 WriteType(ostrm, keep_relabel_data_); 88 if (keep_relabel_data_) 89 WriteType(ostrm, label2index_); 90 WriteType(ostrm, FinalLabel()); 91 WriteType(ostrm, isets_); 92 return true; 93 } 94 RefCount()95 int RefCount() const { return ref_count_.count(); } IncrRefCount()96 int IncrRefCount() { return ref_count_.Incr(); } DecrRefCount()97 int DecrRefCount() { return ref_count_.Decr(); } 98 99 private: LabelReachableData()100 LabelReachableData() {} 101 102 bool reach_input_; // Input or output labels considered? 103 bool keep_relabel_data_; // Save label2index_ to file? 104 bool have_relabel_data_; // Using label2index_? 105 Label final_label_; // Final label 106 RefCounter ref_count_; // Reference count. 107 unordered_map<L, L> label2index_; // Finds index for a label. 108 vector<IntervalSet <L> > isets_; // Interval sets per state. 109 110 DISALLOW_COPY_AND_ASSIGN(LabelReachableData); 111 }; 112 113 114 // Tests reachability of labels from a given state. If reach_input = 115 // true, then input labels are considered, o.w. output labels are 116 // considered. To test for reachability from a state s, first do 117 // SetState(s). Then a label l can be reached from state s of FST f 118 // iff Reach(r) is true where r = Relabel(l). The relabeling is 119 // required to ensure a compact representation of the reachable 120 // labels. 121 122 // The whole FST can be relabeled instead with Relabel(&f, 123 // reach_input) so that the test Reach(r) applies directly to the 124 // labels of the transformed FST f. The relabeled FST will also be 125 // sorted appropriately for composition. 126 // 127 // Reachablity of a final state from state s (via an epsilon path) 128 // can be tested with ReachFinal(); 129 // 130 // Reachability can also be tested on the set of labels specified by 131 // an arc iterator, useful for FST composition. In particular, 132 // Reach(aiter, ...) is true if labels on the input (output) side of 133 // the transitions of the arc iterator, when iter_input is true 134 // (false), can be reached from the state s. The iterator labels must 135 // have already been relabeled. 136 // 137 // With the arc iterator test of reachability, the begin position, end 138 // position and accumulated arc weight of the matches can be 139 // returned. The optional template argument controls how reachable arc 140 // weights are accumulated. The default uses the semiring 141 // Plus(). Alternative ones can be used to distribute the weights in 142 // composition in various ways. 143 template <class A, class S = DefaultAccumulator<A> > 144 class LabelReachable { 145 public: 146 typedef A Arc; 147 typedef typename A::StateId StateId; 148 typedef typename A::Label Label; 149 typedef typename A::Weight Weight; 150 typedef typename IntervalSet<Label>::Interval Interval; 151 152 LabelReachable(const Fst<A> &fst, bool reach_input, S *s = 0, 153 bool keep_relabel_data = true) fst_(new VectorFst<Arc> (fst))154 : fst_(new VectorFst<Arc>(fst)), 155 s_(kNoStateId), 156 data_(new LabelReachableData<Label>(reach_input, keep_relabel_data)), 157 accumulator_(s ? s : new S()), 158 ncalls_(0), 159 nintervals_(0), 160 error_(false) { 161 StateId ins = fst_->NumStates(); 162 TransformFst(); 163 FindIntervals(ins); 164 delete fst_; 165 } 166 167 explicit LabelReachable(LabelReachableData<Label> *data, S *s = 0) 168 : fst_(0), 169 s_(kNoStateId), 170 data_(data), 171 accumulator_(s ? s : new S()), 172 ncalls_(0), 173 nintervals_(0), 174 error_(false) { 175 data_->IncrRefCount(); 176 } 177 LabelReachable(const LabelReachable<A,S> & reachable)178 LabelReachable(const LabelReachable<A, S> &reachable) : 179 fst_(0), 180 s_(kNoStateId), 181 data_(reachable.data_), 182 accumulator_(new S(*reachable.accumulator_)), 183 ncalls_(0), 184 nintervals_(0), 185 error_(reachable.error_) { 186 data_->IncrRefCount(); 187 } 188 ~LabelReachable()189 ~LabelReachable() { 190 if (!data_->DecrRefCount()) 191 delete data_; 192 delete accumulator_; 193 if (ncalls_ > 0) { 194 VLOG(2) << "# of calls: " << ncalls_; 195 VLOG(2) << "# of intervals/call: " << (nintervals_ / ncalls_); 196 } 197 } 198 199 // Relabels w.r.t labels that give compact label sets. Relabel(Label label)200 Label Relabel(Label label) { 201 if (label == 0 || error_) 202 return label; 203 unordered_map<Label, Label> &label2index = *data_->Label2Index(); 204 Label &relabel = label2index[label]; 205 if (!relabel) // Add new label 206 relabel = label2index.size() + 1; 207 return relabel; 208 } 209 210 // Relabels Fst w.r.t to labels that give compact label sets. Relabel(MutableFst<Arc> * fst,bool relabel_input)211 void Relabel(MutableFst<Arc> *fst, bool relabel_input) { 212 for (StateIterator< MutableFst<Arc> > siter(*fst); 213 !siter.Done(); siter.Next()) { 214 StateId s = siter.Value(); 215 for (MutableArcIterator< MutableFst<Arc> > aiter(fst, s); 216 !aiter.Done(); 217 aiter.Next()) { 218 Arc arc = aiter.Value(); 219 if (relabel_input) 220 arc.ilabel = Relabel(arc.ilabel); 221 else 222 arc.olabel = Relabel(arc.olabel); 223 aiter.SetValue(arc); 224 } 225 } 226 if (relabel_input) { 227 ArcSort(fst, ILabelCompare<Arc>()); 228 fst->SetInputSymbols(0); 229 } else { 230 ArcSort(fst, OLabelCompare<Arc>()); 231 fst->SetOutputSymbols(0); 232 } 233 } 234 235 // Returns relabeling pairs (cf. relabel.h::Relabel()). 236 // If 'avoid_collisions' is true, extra pairs are added to 237 // ensure no collisions when relabeling automata that have 238 // labels unseen here. 239 void RelabelPairs(vector<pair<Label, Label> > *pairs, 240 bool avoid_collisions = false) { 241 pairs->clear(); 242 unordered_map<Label, Label> &label2index = *data_->Label2Index(); 243 // Maps labels to their new values in [1, label2index().size()] 244 for (typename unordered_map<Label, Label>::const_iterator 245 it = label2index.begin(); it != label2index.end(); ++it) 246 if (it->second != data_->FinalLabel()) 247 pairs->push_back(pair<Label, Label>(it->first, it->second)); 248 if (avoid_collisions) { 249 // Ensures any label in [1, label2index().size()] is mapped either 250 // by the above step or to label2index() + 1 (to avoid collisions). 251 for (int i = 1; i <= label2index.size(); ++i) { 252 typename unordered_map<Label, Label>::const_iterator 253 it = label2index.find(i); 254 if (it == label2index.end() || it->second == data_->FinalLabel()) 255 pairs->push_back(pair<Label, Label>(i, label2index.size() + 1)); 256 } 257 } 258 } 259 260 // Set current state. Optionally set state associated 261 // with arc iterator to be passed to Reach. 262 void SetState(StateId s, StateId aiter_s = kNoStateId) { 263 s_ = s; 264 if (aiter_s != kNoStateId) { 265 accumulator_->SetState(aiter_s); 266 if (accumulator_->Error()) error_ = true; 267 } 268 } 269 270 // Can reach this label from current state? 271 // Original labels must be transformed by the Relabel methods above. Reach(Label label)272 bool Reach(Label label) { 273 if (label == 0 || error_) 274 return false; 275 vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); 276 return isets[s_].Member(label); 277 278 } 279 280 // Can reach final state (via epsilon transitions) from this state? ReachFinal()281 bool ReachFinal() { 282 if (error_) return false; 283 vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); 284 return isets[s_].Member(data_->FinalLabel()); 285 } 286 287 // Initialize with secondary FST to be used with Reach(Iterator,...). 288 // If copy is true, then 'fst' is a copy of the FST used in the 289 // previous call to this method (useful to avoid unnecessary updates). 290 template <class F> 291 void ReachInit(const F &fst, bool copy = false) { 292 accumulator_->Init(fst, copy); 293 if (accumulator_->Error()) error_ = true; 294 } 295 296 // Can reach any arc iterator label between iterator positions 297 // aiter_begin and aiter_end? If aiter_input = true, then iterator 298 // input labels are considered, o.w. output labels are considered. 299 // Arc iterator labels must be transformed by the Relabel methods 300 // above. If compute_weight is true, user may call ReachWeight(). 301 template <class Iterator> Reach(Iterator * aiter,ssize_t aiter_begin,ssize_t aiter_end,bool aiter_input,bool compute_weight)302 bool Reach(Iterator *aiter, ssize_t aiter_begin, 303 ssize_t aiter_end, bool aiter_input, bool compute_weight) { 304 if (error_) return false; 305 vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); 306 const vector<Interval> *intervals = isets[s_].Intervals(); 307 ++ncalls_; 308 nintervals_ += intervals->size(); 309 310 reach_begin_ = -1; 311 reach_end_ = -1; 312 reach_weight_ = Weight::Zero(); 313 314 uint32 flags = aiter->Flags(); // save flags to restore them on exit 315 aiter->SetFlags(kArcNoCache, kArcNoCache); // make caching optional 316 aiter->Seek(aiter_begin); 317 318 if (2 * (aiter_end - aiter_begin) < intervals->size()) { 319 // Check each arc against intervals. 320 // Set arc iterator flags to only compute the ilabel or olabel values, 321 // since they are the only values required for most of the arcs processed. 322 aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, 323 kArcValueFlags); 324 Label reach_label = kNoLabel; 325 for (ssize_t aiter_pos = aiter_begin; 326 aiter_pos < aiter_end; aiter->Next(), ++aiter_pos) { 327 const A &arc = aiter->Value(); 328 Label label = aiter_input ? arc.ilabel : arc.olabel; 329 if (label == reach_label || Reach(label)) { 330 reach_label = label; 331 if (reach_begin_ < 0) 332 reach_begin_ = aiter_pos; 333 reach_end_ = aiter_pos + 1; 334 if (compute_weight) { 335 if (!(aiter->Flags() & kArcWeightValue)) { 336 // If the 'arc.weight' wasn't computed by the call 337 // to 'aiter->Value()' above, we need to call 338 // 'aiter->Value()' again after having set the arc iterator 339 // flags to compute the arc weight value. 340 aiter->SetFlags(kArcWeightValue, kArcValueFlags); 341 const A &arcb = aiter->Value(); 342 // Call the accumulator. 343 reach_weight_ = accumulator_->Sum(reach_weight_, arcb.weight); 344 // Only ilabel or olabel required to process the following 345 // arcs. 346 aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, 347 kArcValueFlags); 348 } else { 349 // Call the accumulator. 350 reach_weight_ = accumulator_->Sum(reach_weight_, arc.weight); 351 } 352 } 353 } 354 } 355 } else { 356 // Check each interval against arcs 357 ssize_t begin_low, end_low = aiter_begin; 358 for (typename vector<Interval>::const_iterator 359 iiter = intervals->begin(); 360 iiter != intervals->end(); ++iiter) { 361 begin_low = LowerBound(aiter, end_low, aiter_end, 362 aiter_input, iiter->begin); 363 end_low = LowerBound(aiter, begin_low, aiter_end, 364 aiter_input, iiter->end); 365 if (end_low - begin_low > 0) { 366 if (reach_begin_ < 0) 367 reach_begin_ = begin_low; 368 reach_end_ = end_low; 369 if (compute_weight) { 370 aiter->SetFlags(kArcWeightValue, kArcValueFlags); 371 reach_weight_ = accumulator_->Sum(reach_weight_, aiter, 372 begin_low, end_low); 373 } 374 } 375 } 376 } 377 378 aiter->SetFlags(flags, kArcFlags); // restore original flag values 379 return reach_begin_ >= 0; 380 } 381 382 // Returns iterator position of first matching arc. ReachBegin()383 ssize_t ReachBegin() const { return reach_begin_; } 384 385 // Returns iterator position one past last matching arc. ReachEnd()386 ssize_t ReachEnd() const { return reach_end_; } 387 388 // Return the sum of the weights for matching arcs. 389 // Valid only if compute_weight was true in Reach() call. ReachWeight()390 Weight ReachWeight() const { return reach_weight_; } 391 392 // Access to the relabeling map. Excludes epsilon (0) label but 393 // includes kNoLabel that is used internally for super-final 394 // transitons. Label2Index()395 const unordered_map<Label, Label>& Label2Index() const { 396 return *data_->Label2Index(); 397 } 398 GetData()399 LabelReachableData<Label> *GetData() const { return data_; } 400 Error()401 bool Error() const { return error_ || accumulator_->Error(); } 402 403 private: 404 // Redirects labeled arcs (input or output labels determined by 405 // ReachInput()) to new label-specific final states. Each original 406 // final state is redirected via a transition labeled with kNoLabel 407 // to a new kNoLabel-specific final state. Creates super-initial 408 // state for all states with zero in-degree. TransformFst()409 void TransformFst() { 410 StateId ins = fst_->NumStates(); 411 StateId ons = ins; 412 413 vector<ssize_t> indeg(ins, 0); 414 415 // Redirects labeled arcs to new final states. 416 for (StateId s = 0; s < ins; ++s) { 417 for (MutableArcIterator< VectorFst<Arc> > aiter(fst_, s); 418 !aiter.Done(); 419 aiter.Next()) { 420 Arc arc = aiter.Value(); 421 Label label = data_->ReachInput() ? arc.ilabel : arc.olabel; 422 if (label) { 423 if (label2state_.find(label) == label2state_.end()) { 424 label2state_[label] = ons; 425 indeg.push_back(0); 426 ++ons; 427 } 428 arc.nextstate = label2state_[label]; 429 aiter.SetValue(arc); 430 } 431 ++indeg[arc.nextstate]; // Finds in-degrees for next step. 432 } 433 434 // Redirects final weights to new final state. 435 Weight final = fst_->Final(s); 436 if (final != Weight::Zero()) { 437 if (label2state_.find(kNoLabel) == label2state_.end()) { 438 label2state_[kNoLabel] = ons; 439 indeg.push_back(0); 440 ++ons; 441 } 442 Arc arc(kNoLabel, kNoLabel, final, label2state_[kNoLabel]); 443 fst_->AddArc(s, arc); 444 ++indeg[arc.nextstate]; // Finds in-degrees for next step. 445 446 fst_->SetFinal(s, Weight::Zero()); 447 } 448 } 449 450 // Add new final states to Fst. 451 while (fst_->NumStates() < ons) { 452 StateId s = fst_->AddState(); 453 fst_->SetFinal(s, Weight::One()); 454 } 455 456 // Creates a super-initial state for all states with zero in-degree. 457 StateId start = fst_->AddState(); 458 fst_->SetStart(start); 459 for (StateId s = 0; s < start; ++s) { 460 if (indeg[s] == 0) { 461 Arc arc(0, 0, Weight::One(), s); 462 fst_->AddArc(start, arc); 463 } 464 } 465 } 466 FindIntervals(StateId ins)467 void FindIntervals(StateId ins) { 468 StateReachable<A, Label> state_reachable(*fst_); 469 if (state_reachable.Error()) { 470 error_ = true; 471 return; 472 } 473 474 vector<Label> &state2index = state_reachable.State2Index(); 475 vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); 476 isets = state_reachable.IntervalSets(); 477 isets.resize(ins); 478 479 unordered_map<Label, Label> &label2index = *data_->Label2Index(); 480 for (typename unordered_map<Label, StateId>::const_iterator 481 it = label2state_.begin(); 482 it != label2state_.end(); 483 ++it) { 484 Label l = it->first; 485 StateId s = it->second; 486 Label i = state2index[s]; 487 label2index[l] = i; 488 } 489 label2state_.clear(); 490 491 double nintervals = 0; 492 ssize_t non_intervals = 0; 493 for (ssize_t s = 0; s < ins; ++s) { 494 nintervals += isets[s].Size(); 495 if (isets[s].Size() > 1) { 496 ++non_intervals; 497 VLOG(3) << "state: " << s << " # of intervals: " << isets[s].Size(); 498 } 499 } 500 VLOG(2) << "# of states: " << ins; 501 VLOG(2) << "# of intervals: " << nintervals; 502 VLOG(2) << "# of intervals/state: " << nintervals/ins; 503 VLOG(2) << "# of non-interval states: " << non_intervals; 504 } 505 506 template <class Iterator> LowerBound(Iterator * aiter,ssize_t aiter_begin,ssize_t aiter_end,bool aiter_input,Label match_label)507 ssize_t LowerBound(Iterator *aiter, ssize_t aiter_begin, 508 ssize_t aiter_end, bool aiter_input, 509 Label match_label) const { 510 // Only need to compute the ilabel or olabel of arcs when 511 // performing the binary search. 512 aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, 513 kArcValueFlags); 514 ssize_t low = aiter_begin; 515 ssize_t high = aiter_end; 516 while (low < high) { 517 ssize_t mid = (low + high) / 2; 518 aiter->Seek(mid); 519 Label label = aiter_input ? 520 aiter->Value().ilabel : aiter->Value().olabel; 521 if (label > match_label) { 522 high = mid; 523 } else if (label < match_label) { 524 low = mid + 1; 525 } else { 526 // Find first matching label (when non-deterministic) 527 for (ssize_t i = mid; i > low; --i) { 528 aiter->Seek(i - 1); 529 label = aiter_input ? aiter->Value().ilabel : aiter->Value().olabel; 530 if (label != match_label) { 531 aiter->Seek(i); 532 aiter->SetFlags(kArcValueFlags, kArcValueFlags); 533 return i; 534 } 535 } 536 aiter->SetFlags(kArcValueFlags, kArcValueFlags); 537 return low; 538 } 539 } 540 aiter->Seek(low); 541 aiter->SetFlags(kArcValueFlags, kArcValueFlags); 542 return low; 543 } 544 545 VectorFst<Arc> *fst_; 546 StateId s_; // Current state 547 unordered_map<Label, StateId> label2state_; // Finds final state for a label 548 549 ssize_t reach_begin_; // Iterator pos of first match 550 ssize_t reach_end_; // Iterator pos after last match 551 Weight reach_weight_; // Gives weight sum of arc iterator 552 // arcs with reachable labels. 553 LabelReachableData<Label> *data_; // Shareable data between copies 554 S *accumulator_; // Sums arc weights 555 556 double ncalls_; 557 double nintervals_; 558 bool error_; 559 560 void operator=(const LabelReachable<A, S> &); // Disallow 561 }; 562 563 } // namespace fst 564 565 #endif // FST_LIB_LABEL_REACHABLE_H__ 566