1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: jpr@google.com (Jake Ratkiewicz)
16 // Convenience file for including all PDT operations at once, and/or
17 // registering them for new arc types.
18 
19 #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
20 #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
21 
22 #include <utility>
23 using std::pair; using std::make_pair;
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/compose.h>  // for ComposeOptions
28 #include <fst/util.h>
29 
30 #include <fst/script/fst-class.h>
31 #include <fst/script/arg-packs.h>
32 #include <fst/script/shortest-path.h>
33 
34 #include <fst/extensions/pdt/compose.h>
35 #include <fst/extensions/pdt/expand.h>
36 #include <fst/extensions/pdt/info.h>
37 #include <fst/extensions/pdt/replace.h>
38 #include <fst/extensions/pdt/reverse.h>
39 #include <fst/extensions/pdt/shortest-path.h>
40 
41 
42 namespace fst {
43 namespace script {
44 
45 // PDT COMPOSE
46 
47 typedef args::Package<const FstClass &,
48                       const FstClass &,
49                       const vector<pair<int64, int64> >&,
50                       MutableFstClass *,
51                       const PdtComposeOptions &,
52                       bool> PdtComposeArgs;
53 
54 template<class Arc>
PdtCompose(PdtComposeArgs * args)55 void PdtCompose(PdtComposeArgs *args) {
56   const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
57   const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
58   MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
59 
60   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
61       args->arg3.size());
62 
63   for (size_t i = 0; i < parens.size(); ++i) {
64     parens[i].first = args->arg3[i].first;
65     parens[i].second = args->arg3[i].second;
66   }
67 
68   if (args->arg6) {
69     Compose(ifst1, parens, ifst2, ofst, args->arg5);
70   } else {
71     Compose(ifst1, ifst2, parens, ofst, args->arg5);
72   }
73 }
74 
75 void PdtCompose(const FstClass & ifst1,
76                 const FstClass & ifst2,
77                 const vector<pair<int64, int64> > &parens,
78                 MutableFstClass *ofst,
79                 const PdtComposeOptions &copts,
80                 bool left_pdt);
81 
82 // PDT EXPAND
83 
84 struct PdtExpandOptions {
85   bool connect;
86   bool keep_parentheses;
87   WeightClass weight_threshold;
88 
89   PdtExpandOptions(bool c = true, bool k = false,
90                    WeightClass w = WeightClass::Zero())
connectPdtExpandOptions91       : connect(c), keep_parentheses(k), weight_threshold(w) {}
92 };
93 
94 typedef args::Package<const FstClass &,
95                       const vector<pair<int64, int64> >&,
96                       MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
97 
98 template<class Arc>
PdtExpand(PdtExpandArgs * args)99 void PdtExpand(PdtExpandArgs *args) {
100   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
101   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
102 
103   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
104       args->arg2.size());
105   for (size_t i = 0; i < parens.size(); ++i) {
106     parens[i].first = args->arg2[i].first;
107     parens[i].second = args->arg2[i].second;
108   }
109   Expand(fst, parens, ofst,
110          ExpandOptions<Arc>(
111              args->arg4.connect, args->arg4.keep_parentheses,
112              *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
113 }
114 
115 void PdtExpand(const FstClass &ifst,
116                const vector<pair<int64, int64> > &parens,
117                MutableFstClass *ofst, const PdtExpandOptions &opts);
118 
119 void PdtExpand(const FstClass &ifst,
120                const vector<pair<int64, int64> > &parens,
121                MutableFstClass *ofst, bool connect);
122 
123 // PDT REPLACE
124 
125 typedef args::Package<const vector<pair<int64, const FstClass*> > &,
126                       MutableFstClass *,
127                       vector<pair<int64, int64> > *,
128                       const int64 &> PdtReplaceArgs;
129 template<class Arc>
PdtReplace(PdtReplaceArgs * args)130 void PdtReplace(PdtReplaceArgs *args) {
131   vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
132       args->arg1.size());
133   for (size_t i = 0; i < tuples.size(); ++i) {
134     tuples[i].first = args->arg1[i].first;
135     tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
136   }
137   MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
138   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
139       args->arg3->size());
140 
141   for (size_t i = 0; i < parens.size(); ++i) {
142     parens[i].first = args->arg3->at(i).first;
143     parens[i].second = args->arg3->at(i).second;
144   }
145   Replace(tuples, ofst, &parens, args->arg4);
146 
147   // now copy parens back
148   args->arg3->resize(parens.size());
149   for (size_t i = 0; i < parens.size(); ++i) {
150     (*args->arg3)[i].first = parens[i].first;
151     (*args->arg3)[i].second = parens[i].second;
152   }
153 }
154 
155 void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
156                 MutableFstClass *ofst,
157                 vector<pair<int64, int64> > *parens,
158                 const int64 &root);
159 
160 // PDT REVERSE
161 
162 typedef args::Package<const FstClass &,
163                       const vector<pair<int64, int64> >&,
164                       MutableFstClass *> PdtReverseArgs;
165 
166 template<class Arc>
PdtReverse(PdtReverseArgs * args)167 void PdtReverse(PdtReverseArgs *args) {
168   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
169   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
170 
171   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
172       args->arg2.size());
173   for (size_t i = 0; i < parens.size(); ++i) {
174     parens[i].first = args->arg2[i].first;
175     parens[i].second = args->arg2[i].second;
176   }
177   Reverse(fst, parens, ofst);
178 }
179 
180 void PdtReverse(const FstClass &ifst,
181                 const vector<pair<int64, int64> > &parens,
182                 MutableFstClass *ofst);
183 
184 
185 // PDT SHORTESTPATH
186 
187 struct PdtShortestPathOptions {
188   QueueType queue_type;
189   bool keep_parentheses;
190   bool path_gc;
191 
192   PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
193                          bool kp = false, bool gc = true)
queue_typePdtShortestPathOptions194       : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
195 };
196 
197 typedef args::Package<const FstClass &,
198                       const vector<pair<int64, int64> >&,
199                       MutableFstClass *,
200                       const PdtShortestPathOptions &> PdtShortestPathArgs;
201 
202 template<class Arc>
PdtShortestPath(PdtShortestPathArgs * args)203 void PdtShortestPath(PdtShortestPathArgs *args) {
204   typedef typename Arc::StateId StateId;
205   typedef typename Arc::Label Label;
206   typedef typename Arc::Weight Weight;
207 
208   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
209   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
210   const PdtShortestPathOptions &opts = args->arg4;
211 
212 
213   vector<pair<Label, Label> > parens(args->arg2.size());
214   for (size_t i = 0; i < parens.size(); ++i) {
215     parens[i].first = args->arg2[i].first;
216     parens[i].second = args->arg2[i].second;
217   }
218 
219   switch (opts.queue_type) {
220     default:
221       FSTERROR() << "Unknown queue type: " << opts.queue_type;
222     case FIFO_QUEUE: {
223       typedef FifoQueue<StateId> Queue;
224       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
225                                                          opts.path_gc);
226       ShortestPath(fst, parens, ofst, spopts);
227       return;
228     }
229     case LIFO_QUEUE: {
230       typedef LifoQueue<StateId> Queue;
231       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
232                                                          opts.path_gc);
233       ShortestPath(fst, parens, ofst, spopts);
234       return;
235     }
236     case STATE_ORDER_QUEUE: {
237       typedef StateOrderQueue<StateId> Queue;
238       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
239                                                          opts.path_gc);
240       ShortestPath(fst, parens, ofst, spopts);
241       return;
242     }
243   }
244 }
245 
246 void PdtShortestPath(const FstClass &ifst,
247                      const vector<pair<int64, int64> > &parens,
248                      MutableFstClass *ofst,
249                      const PdtShortestPathOptions &opts =
250                      PdtShortestPathOptions());
251 
252 // PRINT INFO
253 
254 typedef args::Package<const FstClass &,
255                       const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
256 
257 template<class Arc>
PrintPdtInfo(PrintPdtInfoArgs * args)258 void PrintPdtInfo(PrintPdtInfoArgs *args) {
259   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
260   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
261       args->arg2.size());
262   for (size_t i = 0; i < parens.size(); ++i) {
263     parens[i].first = args->arg2[i].first;
264     parens[i].second = args->arg2[i].second;
265   }
266   PdtInfo<Arc> pdtinfo(fst, parens);
267   PrintPdtInfo(pdtinfo);
268 }
269 
270 void PrintPdtInfo(const FstClass &ifst,
271                   const vector<pair<int64, int64> > &parens);
272 
273 }  // namespace script
274 }  // namespace fst
275 
276 
// Invokes REGISTER_FST_OPERATION once per script-level PDT operation in this
// header (PdtCompose, PdtExpand, PdtReplace, PdtReverse, PdtShortestPath,
// PrintPdtInfo) for ArcType, each paired with its argument-pack typedef.
// Per the file header, this is how the operations are registered for new
// arc types. The final invocation has no trailing semicolon, so the use site
// supplies it, e.g. `REGISTER_FST_PDT_OPERATIONS(MyArc);`.
#define REGISTER_FST_PDT_OPERATIONS(ArcType)                                \
  REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs);              \
  REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs);                \
  REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs);              \
  REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs);              \
  REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs);    \
  REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
284 #endif  // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
285