1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // Author: jpr@google.com (Jake Ratkiewicz) 16 17 #ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_ 18 #define FST_SCRIPT_SHORTEST_DISTANCE_H_ 19 20 #include <vector> 21 using std::vector; 22 23 #include <fst/script/arg-packs.h> 24 #include <fst/script/fst-class.h> 25 #include <fst/script/weight-class.h> 26 #include <fst/script/prune.h> // for ArcFilterType 27 #include <fst/queue.h> // for QueueType 28 #include <fst/shortest-distance.h> 29 30 namespace fst { 31 namespace script { 32 33 enum ArcFilterType { ANY_ARC_FILTER, EPSILON_ARC_FILTER, 34 INPUT_EPSILON_ARC_FILTER, OUTPUT_EPSILON_ARC_FILTER }; 35 36 // See nlp/fst/lib/shortest-distance.h for the template options class 37 // that this one shadows 38 struct ShortestDistanceOptions { 39 const QueueType queue_type; 40 const ArcFilterType arc_filter_type; 41 const int64 source; 42 const float delta; 43 const bool first_path; 44 ShortestDistanceOptionsShortestDistanceOptions45 ShortestDistanceOptions(QueueType qt, ArcFilterType aft, int64 s, 46 float d) 47 : queue_type(qt), arc_filter_type(aft), source(s), delta(d), 48 first_path(false) { } 49 }; 50 51 52 53 // 1 54 typedef args::Package<const FstClass &, vector<WeightClass> *, 55 const ShortestDistanceOptions &> ShortestDistanceArgs1; 56 57 template<class Queue, class Arc, class ArcFilter> 58 struct QueueConstructor { 59 // template<class Arc, class ArcFilter> ConstructQueueConstructor60 static Queue *Construct(const Fst<Arc> &, 61 const vector<typename Arc::Weight> *) { 62 return new Queue(); 63 } 64 }; 65 66 // Specializations to deal with AutoQueue, NaturalShortestFirstQueue, 67 // and TopOrderQueue's different constructors 68 template<class Arc, class ArcFilter> 69 struct QueueConstructor<AutoQueue<typename Arc::StateId>, Arc, ArcFilter> { 70 // template<class Arc, class ArcFilter> 71 static AutoQueue<typename Arc::StateId> *Construct( 72 const Fst<Arc> &fst, 73 const vector<typename Arc::Weight> *distance) { 74 return new AutoQueue<typename Arc::StateId>(fst, distance, ArcFilter()); 75 } 76 }; 77 78 template<class Arc, class ArcFilter> 79 struct QueueConstructor<NaturalShortestFirstQueue<typename Arc::StateId, 80 typename Arc::Weight>, 81 Arc, ArcFilter> { 82 // template<class Arc, class ArcFilter> 83 static NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight> 84 *Construct(const Fst<Arc> &fst, 85 const vector<typename Arc::Weight> *distance) { 86 return new NaturalShortestFirstQueue<typename Arc::StateId, 87 typename Arc::Weight>(*distance); 88 } 89 }; 90 91 template<class Arc, class ArcFilter> 92 struct QueueConstructor<TopOrderQueue<typename Arc::StateId>, Arc, ArcFilter> { 93 // template<class Arc, class ArcFilter> 94 static TopOrderQueue<typename Arc::StateId> *Construct( 95 const Fst<Arc> &fst, const vector<typename Arc::Weight> *weights) { 96 return new TopOrderQueue<typename Arc::StateId>(fst, ArcFilter()); 97 } 98 }; 99 100 101 template<class Arc, class Queue> 102 void ShortestDistanceHelper(ShortestDistanceArgs1 *args) { 103 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 104 const ShortestDistanceOptions &opts = args->arg3; 105 106 vector<typename Arc::Weight> weights; 107 108 switch (opts.arc_filter_type) { 109 case ANY_ARC_FILTER: { 110 Queue *queue = 111 QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct( 112 fst, &weights); 113 fst::ShortestDistanceOptions<Arc, Queue, AnyArcFilter<Arc> > sdopts( 114 queue, AnyArcFilter<Arc>(), opts.source, opts.delta); 115 ShortestDistance(fst, &weights, sdopts); 116 delete queue; 117 break; 118 } 119 case EPSILON_ARC_FILTER: { 120 Queue *queue = 121 QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct( 122 fst, &weights); 123 fst::ShortestDistanceOptions<Arc, Queue, 124 EpsilonArcFilter<Arc> > sdopts( 125 queue, EpsilonArcFilter<Arc>(), opts.source, opts.delta); 126 ShortestDistance(fst, &weights, sdopts); 127 delete queue; 128 break; 129 } 130 case INPUT_EPSILON_ARC_FILTER: { 131 Queue *queue = 132 QueueConstructor<Queue, Arc, InputEpsilonArcFilter<Arc> >::Construct( 133 fst, &weights); 134 fst::ShortestDistanceOptions<Arc, Queue, 135 InputEpsilonArcFilter<Arc> > sdopts( 136 queue, InputEpsilonArcFilter<Arc>(), opts.source, opts.delta); 137 ShortestDistance(fst, &weights, sdopts); 138 delete queue; 139 break; 140 } 141 case OUTPUT_EPSILON_ARC_FILTER: { 142 Queue *queue = 143 QueueConstructor<Queue, Arc, 144 OutputEpsilonArcFilter<Arc> >::Construct( 145 fst, &weights); 146 fst::ShortestDistanceOptions<Arc, Queue, 147 OutputEpsilonArcFilter<Arc> > sdopts( 148 queue, OutputEpsilonArcFilter<Arc>(), opts.source, opts.delta); 149 ShortestDistance(fst, &weights, sdopts); 150 delete queue; 151 break; 152 } 153 } 154 155 // Copy the weights back 156 args->arg2->resize(weights.size()); 157 for (unsigned i = 0; i < weights.size(); ++i) { 158 (*args->arg2)[i] = WeightClass(weights[i]); 159 } 160 } 161 162 template<class Arc> 163 void ShortestDistance(ShortestDistanceArgs1 *args) { 164 const ShortestDistanceOptions &opts = args->arg3; 165 typedef typename Arc::StateId StateId; 166 typedef typename Arc::Weight Weight; 167 168 // Must consider (opts.queue_type x opts.filter_type) options 169 switch (opts.queue_type) { 170 default: 171 FSTERROR() << "Unknown queue type." << opts.queue_type; 172 173 case AUTO_QUEUE: 174 ShortestDistanceHelper<Arc, AutoQueue<StateId> >(args); 175 return; 176 177 case FIFO_QUEUE: 178 ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); 179 return; 180 181 case LIFO_QUEUE: 182 ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); 183 return; 184 185 case SHORTEST_FIRST_QUEUE: 186 ShortestDistanceHelper<Arc, 187 NaturalShortestFirstQueue<StateId, Weight> >(args); 188 return; 189 190 case STATE_ORDER_QUEUE: 191 ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); 192 return; 193 194 case TOP_ORDER_QUEUE: 195 ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); 196 return; 197 } 198 } 199 200 // 2 201 typedef args::Package<const FstClass&, vector<WeightClass>*, 202 bool, double> ShortestDistanceArgs2; 203 204 template<class Arc> 205 void ShortestDistance(ShortestDistanceArgs2 *args) { 206 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 207 vector<typename Arc::Weight> distance; 208 209 ShortestDistance(fst, &distance, args->arg3, args->arg4); 210 211 // convert the typed weights back into weightclass 212 vector<WeightClass> *retval = args->arg2; 213 retval->resize(distance.size()); 214 215 for (unsigned i = 0; i < distance.size(); ++i) { 216 (*retval)[i] = WeightClass(distance[i]); 217 } 218 } 219 220 // 3 221 typedef args::WithReturnValue<WeightClass, 222 const FstClass &> ShortestDistanceArgs3; 223 224 template<class Arc> 225 void ShortestDistance(ShortestDistanceArgs3 *args) { 226 const Fst<Arc> &fst = *(args->args.GetFst<Arc>()); 227 228 args->retval = WeightClass(ShortestDistance(fst)); 229 } 230 231 232 // 1 233 void ShortestDistance(const FstClass &fst, vector<WeightClass> *distance, 234 const ShortestDistanceOptions &opts); 235 236 // 2 237 void ShortestDistance(const FstClass &ifst, vector<WeightClass> *distance, 238 bool reverse = false, double delta = fst::kDelta); 239 240 #ifndef SWIG 241 // 3 242 WeightClass ShortestDistance(const FstClass &ifst); 243 #endif 244 245 } // namespace script 246 } // namespace fst 247 248 249 250 #endif // FST_SCRIPT_SHORTEST_DISTANCE_H_ 251