1 // shortest-distance.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions and classes to find shortest distance in an FST.
20
21 #ifndef FST_LIB_SHORTEST_DISTANCE_H__
22 #define FST_LIB_SHORTEST_DISTANCE_H__
23
24 #include <deque>
25 using std::deque;
26 #include <vector>
27 using std::vector;
28
29 #include <fst/arcfilter.h>
30 #include <fst/cache.h>
31 #include <fst/queue.h>
32 #include <fst/reverse.h>
33 #include <fst/test-properties.h>
34
35
36 namespace fst {
37
38 template <class Arc, class Queue, class ArcFilter>
39 struct ShortestDistanceOptions {
40 typedef typename Arc::StateId StateId;
41
42 Queue *state_queue; // Queue discipline used; owned by caller
43 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph)
44 StateId source; // If kNoStateId, use the Fst's initial state
45 float delta; // Determines the degree of convergence required
46 bool first_path; // For a semiring with the path property (o.w.
47 // undefined), compute the shortest-distances along
48 // along the first path to a final state found
49 // by the algorithm. That path is the shortest-path
50 // only if the FST has a unique final state (or all
51 // the final states have the same final weight), the
52 // queue discipline is shortest-first and all the
53 // weights in the FST are between One() and Zero()
54 // according to NaturalLess.
55
56 ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
57 float d = kDelta)
state_queueShortestDistanceOptions58 : state_queue(q), arc_filter(filt), source(src), delta(d),
59 first_path(false) {}
60 };
61
62
63 // Computation state of the shortest-distance algorithm. Reusable
64 // information is maintained across calls to member function
65 // ShortestDistance(source) when 'retain' is true for improved
66 // efficiency when calling multiple times from different source states
67 // (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
68 // may not be freed before this class. Vector 'distance' should not be
69 // modified by the user between these calls.
70 // The Error() method returns true if an error was encountered.
71 template<class Arc, class Queue, class ArcFilter>
72 class ShortestDistanceState {
73 public:
74 typedef typename Arc::StateId StateId;
75 typedef typename Arc::Weight Weight;
76
ShortestDistanceState(const Fst<Arc> & fst,vector<Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts,bool retain)77 ShortestDistanceState(
78 const Fst<Arc> &fst,
79 vector<Weight> *distance,
80 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
81 bool retain)
82 : fst_(fst), distance_(distance), state_queue_(opts.state_queue),
83 arc_filter_(opts.arc_filter), delta_(opts.delta),
84 first_path_(opts.first_path), retain_(retain), source_id_(0),
85 error_(false) {
86 distance_->clear();
87 }
88
~ShortestDistanceState()89 ~ShortestDistanceState() {}
90
91 void ShortestDistance(StateId source);
92
Error()93 bool Error() const { return error_; }
94
95 private:
96 const Fst<Arc> &fst_;
97 vector<Weight> *distance_;
98 Queue *state_queue_;
99 ArcFilter arc_filter_;
100 float delta_;
101 bool first_path_;
102 bool retain_; // Retain and reuse information across calls
103
104 vector<Weight> rdistance_; // Relaxation distance.
105 vector<bool> enqueued_; // Is state enqueued?
106 vector<StateId> sources_; // Source ID for ith state in 'distance_',
107 // 'rdistance_', and 'enqueued_' if retained.
108 StateId source_id_; // Unique ID characterizing each call to SD
109
110 bool error_;
111 };
112
113 // Compute the shortest distance. If 'source' is kNoStateId, use
114 // the initial state of the Fst.
115 template <class Arc, class Queue, class ArcFilter>
ShortestDistance(StateId source)116 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
117 StateId source) {
118 if (fst_.Start() == kNoStateId) {
119 if (fst_.Properties(kError, false)) error_ = true;
120 return;
121 }
122
123 if (!(Weight::Properties() & kRightSemiring)) {
124 FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
125 << Weight::Type();
126 error_ = true;
127 return;
128 }
129
130 if (first_path_ && !(Weight::Properties() & kPath)) {
131 FSTERROR() << "ShortestDistance: first_path option disallowed when "
132 << "Weight does not have the path property: "
133 << Weight::Type();
134 error_ = true;
135 return;
136 }
137
138 state_queue_->Clear();
139
140 if (!retain_) {
141 distance_->clear();
142 rdistance_.clear();
143 enqueued_.clear();
144 }
145
146 if (source == kNoStateId)
147 source = fst_.Start();
148
149 while (distance_->size() <= source) {
150 distance_->push_back(Weight::Zero());
151 rdistance_.push_back(Weight::Zero());
152 enqueued_.push_back(false);
153 }
154 if (retain_) {
155 while (sources_.size() <= source)
156 sources_.push_back(kNoStateId);
157 sources_[source] = source_id_;
158 }
159 (*distance_)[source] = Weight::One();
160 rdistance_[source] = Weight::One();
161 enqueued_[source] = true;
162
163 state_queue_->Enqueue(source);
164
165 while (!state_queue_->Empty()) {
166 StateId s = state_queue_->Head();
167 state_queue_->Dequeue();
168 while (distance_->size() <= s) {
169 distance_->push_back(Weight::Zero());
170 rdistance_.push_back(Weight::Zero());
171 enqueued_.push_back(false);
172 }
173 if (first_path_ && (fst_.Final(s) != Weight::Zero()))
174 break;
175 enqueued_[s] = false;
176 Weight r = rdistance_[s];
177 rdistance_[s] = Weight::Zero();
178 for (ArcIterator< Fst<Arc> > aiter(fst_, s);
179 !aiter.Done();
180 aiter.Next()) {
181 const Arc &arc = aiter.Value();
182 if (!arc_filter_(arc))
183 continue;
184 while (distance_->size() <= arc.nextstate) {
185 distance_->push_back(Weight::Zero());
186 rdistance_.push_back(Weight::Zero());
187 enqueued_.push_back(false);
188 }
189 if (retain_) {
190 while (sources_.size() <= arc.nextstate)
191 sources_.push_back(kNoStateId);
192 if (sources_[arc.nextstate] != source_id_) {
193 (*distance_)[arc.nextstate] = Weight::Zero();
194 rdistance_[arc.nextstate] = Weight::Zero();
195 enqueued_[arc.nextstate] = false;
196 sources_[arc.nextstate] = source_id_;
197 }
198 }
199 Weight &nd = (*distance_)[arc.nextstate];
200 Weight &nr = rdistance_[arc.nextstate];
201 Weight w = Times(r, arc.weight);
202 if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
203 nd = Plus(nd, w);
204 nr = Plus(nr, w);
205 if (!nd.Member() || !nr.Member()) {
206 error_ = true;
207 return;
208 }
209 if (!enqueued_[arc.nextstate]) {
210 state_queue_->Enqueue(arc.nextstate);
211 enqueued_[arc.nextstate] = true;
212 } else {
213 state_queue_->Update(arc.nextstate);
214 }
215 }
216 }
217 }
218 ++source_id_;
219 if (fst_.Properties(kError, false)) error_ = true;
220 }
221
222
223 // Shortest-distance algorithm: this version allows fine control
224 // via the options argument. See below for a simpler interface.
225 //
226 // This computes the shortest distance from the 'opts.source' state to
227 // each visited state S and stores the value in the 'distance' vector.
228 // An unvisited state S has distance Zero(), which will be stored in
229 // the 'distance' vector if S is less than the maximum visited state.
230 // The state queue discipline, arc filter, and convergence delta are
231 // taken in the options argument.
232 // The 'distance' vector will contain a unique element for which
233 // Member() is false if an error was encountered.
234 //
235 // The weights must must be right distributive and k-closed (i.e., 1 +
236 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
237 //
238 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
239 // Shortest-Distance Problems", Journal of Automata, Languages and
240 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
241 // depends on the properties of the semiring and the queue discipline
242 // used. Refer to the paper for more details.
243 template<class Arc, class Queue, class ArcFilter>
ShortestDistance(const Fst<Arc> & fst,vector<typename Arc::Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts)244 void ShortestDistance(
245 const Fst<Arc> &fst,
246 vector<typename Arc::Weight> *distance,
247 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
248
249 ShortestDistanceState<Arc, Queue, ArcFilter>
250 sd_state(fst, distance, opts, false);
251 sd_state.ShortestDistance(opts.source);
252 if (sd_state.Error()) {
253 distance->clear();
254 distance->resize(1, Arc::Weight::NoWeight());
255 }
256 }
257
258 // Shortest-distance algorithm: simplified interface. See above for a
259 // version that allows finer control.
260 //
261 // If 'reverse' is false, this computes the shortest distance from the
262 // initial state to each state S and stores the value in the
263 // 'distance' vector. If 'reverse' is true, this computes the shortest
264 // distance from each state to the final states. An unvisited state S
265 // has distance Zero(), which will be stored in the 'distance' vector
266 // if S is less than the maximum visited state. The state queue
267 // discipline is automatically-selected.
268 // The 'distance' vector will contain a unique element for which
269 // Member() is false if an error was encountered.
270 //
271 // The weights must must be right (left) distributive if reverse is
272 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
273 // x + x^2 + ... + x^k).
274 //
275 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
276 // Shortest-Distance Problems", Journal of Automata, Languages and
277 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
278 // depends on the properties of the semiring and the queue discipline
279 // used. Refer to the paper for more details.
280 template<class Arc>
281 void ShortestDistance(const Fst<Arc> &fst,
282 vector<typename Arc::Weight> *distance,
283 bool reverse = false,
284 float delta = kDelta) {
285 typedef typename Arc::StateId StateId;
286 typedef typename Arc::Weight Weight;
287
288 if (!reverse) {
289 AnyArcFilter<Arc> arc_filter;
290 AutoQueue<StateId> state_queue(fst, distance, arc_filter);
291 ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
292 opts(&state_queue, arc_filter);
293 opts.delta = delta;
294 ShortestDistance(fst, distance, opts);
295 } else {
296 typedef ReverseArc<Arc> ReverseArc;
297 typedef typename ReverseArc::Weight ReverseWeight;
298 AnyArcFilter<ReverseArc> rarc_filter;
299 VectorFst<ReverseArc> rfst;
300 Reverse(fst, &rfst);
301 vector<ReverseWeight> rdistance;
302 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
303 ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
304 AnyArcFilter<ReverseArc> >
305 ropts(&state_queue, rarc_filter);
306 ropts.delta = delta;
307 ShortestDistance(rfst, &rdistance, ropts);
308 distance->clear();
309 if (rdistance.size() == 1 && !rdistance[0].Member()) {
310 distance->resize(1, Arc::Weight::NoWeight());
311 return;
312 }
313 while (distance->size() < rdistance.size() - 1)
314 distance->push_back(rdistance[distance->size() + 1].Reverse());
315 }
316 }
317
318
319 // Return the sum of the weight of all successful paths in an FST, i.e.,
320 // the shortest-distance from the initial state to the final states.
321 // Returns a weight such that Member() is false if an error was encountered.
322 template <class Arc>
323 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) {
324 typedef typename Arc::Weight Weight;
325 typedef typename Arc::StateId StateId;
326 vector<Weight> distance;
327 if (Weight::Properties() & kRightSemiring) {
328 ShortestDistance(fst, &distance, false, delta);
329 if (distance.size() == 1 && !distance[0].Member())
330 return Arc::Weight::NoWeight();
331 Weight sum = Weight::Zero();
332 for (StateId s = 0; s < distance.size(); ++s)
333 sum = Plus(sum, Times(distance[s], fst.Final(s)));
334 return sum;
335 } else {
336 ShortestDistance(fst, &distance, true, delta);
337 StateId s = fst.Start();
338 if (distance.size() == 1 && !distance[0].Member())
339 return Arc::Weight::NoWeight();
340 return s != kNoStateId && s < distance.size() ?
341 distance[s] : Weight::Zero();
342 }
343 }
344
345
346 } // namespace fst
347
348 #endif // FST_LIB_SHORTEST_DISTANCE_H__
349