1 // state-reachable.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Class to determine whether a given (final) state can be reached from some
20 // other given state.
21 
22 #ifndef FST_LIB_STATE_REACHABLE_H__
23 #define FST_LIB_STATE_REACHABLE_H__
24 
25 #include <vector>
26 using std::vector;
27 
28 #include <fst/dfs-visit.h>
29 #include <fst/fst.h>
30 #include <fst/interval-set.h>
31 
32 
33 namespace fst {
34 
35 // Computes the (final) states reachable from a given state in an FST.
36 // After this visitor has been called, a final state f can be reached
37 // from a state s iff (*isets)[s].Member(state2index[f]) is true, where
38 // (*isets[s]) is a set of half-open inteval of final state indices
39 // and state2index[f] maps from a final state to its index.
40 //
41 // If state2index is empty, it is filled-in with suitable indices.
42 // If it is non-empty, those indices are used; in this case, the
43 // final states must have out-degree 0.
44 template <class A, typename I = typename A::StateId>
45 class IntervalReachVisitor {
46  public:
47   typedef typename A::StateId StateId;
48   typedef typename A::Label Label;
49   typedef typename A::Weight Weight;
50   typedef typename IntervalSet<I>::Interval Interval;
51 
IntervalReachVisitor(const Fst<A> & fst,vector<IntervalSet<I>> * isets,vector<I> * state2index)52   IntervalReachVisitor(const Fst<A> &fst,
53                        vector< IntervalSet<I> > *isets,
54                        vector<I> *state2index)
55       : fst_(fst),
56         isets_(isets),
57         state2index_(state2index),
58         index_(state2index->empty() ? 1 : -1),
59         error_(false) {
60     isets_->clear();
61   }
62 
InitVisit(const Fst<A> & fst)63   void InitVisit(const Fst<A> &fst) { error_ = false; }
64 
InitState(StateId s,StateId r)65   bool InitState(StateId s, StateId r) {
66     while (isets_->size() <= s)
67       isets_->push_back(IntervalSet<Label>());
68     while (state2index_->size() <= s)
69       state2index_->push_back(-1);
70 
71     if (fst_.Final(s) != Weight::Zero()) {
72       // Create tree interval
73       vector<Interval> *intervals = (*isets_)[s].Intervals();
74       if (index_ < 0) {  // Use state2index_ map to set index
75         if (fst_.NumArcs(s) > 0) {
76           FSTERROR() << "IntervalReachVisitor: state2index map must be empty "
77                      << "for this FST";
78           error_ = true;
79           return false;
80         }
81         I index = (*state2index_)[s];
82         if (index < 0) {
83           FSTERROR() << "IntervalReachVisitor: state2index map incomplete";
84           error_ = true;
85           return false;
86         }
87         intervals->push_back(Interval(index, index + 1));
88       } else {           // Use pre-order index
89         intervals->push_back(Interval(index_, index_ + 1));
90         (*state2index_)[s] = index_++;
91       }
92     }
93     return true;
94   }
95 
TreeArc(StateId s,const A & arc)96   bool TreeArc(StateId s, const A &arc) {
97     return true;
98   }
99 
BackArc(StateId s,const A & arc)100   bool BackArc(StateId s, const A &arc) {
101     FSTERROR() << "IntervalReachVisitor: cyclic input";
102     error_ = true;
103     return false;
104   }
105 
ForwardOrCrossArc(StateId s,const A & arc)106   bool ForwardOrCrossArc(StateId s, const A &arc) {
107     // Non-tree interval
108     (*isets_)[s].Union((*isets_)[arc.nextstate]);
109     return true;
110   }
111 
FinishState(StateId s,StateId p,const A * arc)112   void FinishState(StateId s, StateId p, const A *arc) {
113     if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) {
114       vector<Interval> *intervals = (*isets_)[s].Intervals();
115       (*intervals)[0].end = index_;      // Update tree interval end
116     }
117     (*isets_)[s].Normalize();
118     if (p != kNoStateId)
119       (*isets_)[p].Union((*isets_)[s]);  // Propagate intervals to parent
120   }
121 
FinishVisit()122   void FinishVisit() {}
123 
Error()124   bool Error() const { return error_; }
125 
126  private:
127   const Fst<A> &fst_;
128   vector< IntervalSet<I> > *isets_;
129   vector<I> *state2index_;
130   I index_;
131   bool error_;
132 };
133 
134 
135 // Tests reachability of final states from a given state. To test for
136 // reachability from a state s, first do SetState(s). Then a final
137 // state f can be reached from state s of FST iff Reach(f) is true.
138 template <class A, typename I = typename A::StateId>
139 class StateReachable {
140  public:
141   typedef A Arc;
142   typedef I Index;
143   typedef typename A::StateId StateId;
144   typedef typename A::Label Label;
145   typedef typename A::Weight Weight;
146   typedef typename IntervalSet<I>::Interval Interval;
147 
StateReachable(const Fst<A> & fst)148   StateReachable(const Fst<A> &fst)
149       : error_(false) {
150     IntervalReachVisitor<Arc> reach_visitor(fst, &isets_, &state2index_);
151     DfsVisit(fst, &reach_visitor);
152     if (reach_visitor.Error()) error_ = true;
153   }
154 
StateReachable(const StateReachable<A> & reachable)155   StateReachable(const StateReachable<A> &reachable) {
156     FSTERROR() << "Copy constructor for state reachable class "
157                << "not yet implemented.";
158     error_ = true;
159   }
160 
161   // Set current state.
SetState(StateId s)162   void SetState(StateId s) { s_ = s; }
163 
164   // Can reach this label from current state?
Reach(StateId s)165   bool Reach(StateId s) {
166     if (s >= state2index_.size())
167       return false;
168 
169     I i =  state2index_[s];
170     if (i < 0) {
171       FSTERROR() << "StateReachable: state non-final: " << s;
172       error_ = true;
173       return false;
174     }
175     return isets_[s_].Member(i);
176   }
177 
178   // Access to the state-to-index mapping. Unassigned states have index -1.
State2Index()179   vector<I> &State2Index() { return state2index_; }
180 
181   // Access to the interval sets. These specify the reachability
182   // to the final states as intervals of the final state indices.
IntervalSets()183   const vector< IntervalSet<I> > &IntervalSets() { return isets_; }
184 
Error()185   bool Error() const { return error_; }
186 
187  private:
188   StateId s_;                                 // Current state
189   vector< IntervalSet<I> > isets_;            // Interval sets per state
190   vector<I> state2index_;                     // Finds index for a final state
191   bool error_;
192 
193   void operator=(const StateReachable<A> &);  // Disallow
194 };
195 
196 }  // namespace fst
197 
198 #endif  // FST_LIB_STATE_REACHABLE_H__
199