1 // replace.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Recursively replace Fst arcs with other Fst(s) returning a PDT.
20 
21 #ifndef FST_EXTENSIONS_PDT_REPLACE_H__
22 #define FST_EXTENSIONS_PDT_REPLACE_H__
23 
24 #include <tr1/unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 
28 #include <fst/replace.h>
29 
30 namespace fst {
31 
// Hash functor for the (size_t, S) keys of the parenthesis ID map.
// The hash value is key.first + key.second * kPrime.
template <typename S>
struct ReplaceParenHash {
 private:
  // Multiplier applied to the second component of the key.
  static const size_t kPrime = 7853;

 public:
  // Returns the hash of key 'p'.
  size_t operator()(const pair<size_t, S> &p) const {
    return kPrime * p.second + p.first;
  }
};

// Out-of-class definition of the static data member declared above.
template <typename S> const size_t ReplaceParenHash<S>::kPrime;
43 
44 // Builds a pushdown transducer (PDT) from an RTN specification
45 // identical to that in fst/lib/replace.h. The result is a PDT
46 // encoded as the FST 'ofst' where some transitions are labeled with
47 // open or close parentheses. To be interpreted as a PDT, the parens
48 // must balance on a path (see PdtExpand()). The open/close
49 // parenthesis label pairs are returned in 'parens'.
50 template <class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,vector<pair<typename Arc::Label,typename Arc::Label>> * parens,typename Arc::Label root)51 void Replace(const vector<pair<typename Arc::Label,
52              const Fst<Arc>* > >& ifst_array,
53              MutableFst<Arc> *ofst,
54              vector<pair<typename Arc::Label,
55              typename Arc::Label> > *parens,
56              typename Arc::Label root) {
57   typedef typename Arc::Label Label;
58   typedef typename Arc::StateId StateId;
59   typedef typename Arc::Weight Weight;
60 
61   ofst->DeleteStates();
62   parens->clear();
63 
64   unordered_map<Label, size_t> label2id;
65   for (size_t i = 0; i < ifst_array.size(); ++i)
66     label2id[ifst_array[i].first] = i;
67 
68   Label max_label = kNoLabel;
69   size_t max_non_term_count = 0;
70 
71   // Queue of non-terminals to replace
72   deque<size_t> non_term_queue;
73   // Map of non-terminals to replace to count
74   unordered_map<Label, size_t> non_term_map;
75   non_term_queue.push_back(root);
76   non_term_map[root] = 1;;
77 
78   // PDT state corr. to ith replace FST start state.
79   vector<StateId> fst_start(ifst_array.size(), kNoLabel);
80   // PDT state, weight pairs corr. to ith replace FST final state & weights.
81   vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
82 
83   // Builds single Fst combining all referenced input Fsts. Leaves in the
84   // non-termnals for now.  Tabulate the PDT states that correspond to
85   // the start and final states of the input Fsts.
86   for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
87     Label label = non_term_queue.front();
88     non_term_queue.pop_front();
89     size_t fst_id = label2id[label];
90 
91     const Fst<Arc> *ifst = ifst_array[fst_id].second;
92     for (StateIterator< Fst<Arc> > siter(*ifst);
93          !siter.Done(); siter.Next()) {
94       StateId is = siter.Value();
95       StateId os = ofst->AddState();
96       if (is == ifst->Start()) {
97         fst_start[fst_id] = os;
98         if (label == root)
99           ofst->SetStart(os);
100       }
101       if (ifst->Final(is) != Weight::Zero()) {
102         if (label == root)
103           ofst->SetFinal(os, ifst->Final(is));
104         fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
105       }
106       for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
107            !aiter.Done(); aiter.Next()) {
108         Arc arc = aiter.Value();
109         if (max_label == kNoLabel || arc.olabel > max_label)
110           max_label = arc.olabel;
111         typename unordered_map<Label, size_t>::const_iterator it =
112             label2id.find(arc.olabel);
113         if (it != label2id.end()) {
114           size_t nfst_id = it->second;
115           if (ifst_array[nfst_id].second->Start() == -1)
116             continue;
117           size_t count = non_term_map[arc.olabel]++;
118           if (count == 0)
119             non_term_queue.push_back(arc.olabel);
120           if (count > max_non_term_count)
121             max_non_term_count = count;
122         }
123         arc.nextstate += soff;
124         ofst->AddArc(os, arc);
125       }
126     }
127   }
128 
129   // Changes each non-terminal transition to an open parenthesis
130   // transition redirected to the PDT state that corresponds to the
131   // start state of the input FST for the non-terminal. Adds close parenthesis
132   // transitions from the PDT states corr. to the final states of the
133   // input FST for the non-terminal to the former destination state of the
134   // non-terminal transition.
135 
136   typedef MutableArcIterator< MutableFst<Arc> > MIter;
137   typedef unordered_map<pair<size_t, StateId >, size_t,
138                    ReplaceParenHash<StateId> > ParenMap;
139 
140   // Parenthesis pair ID per fst, state pair.
141   ParenMap paren_map;
142   // # of parenthesis pairs per fst.
143   vector<size_t> nparens(ifst_array.size(), 0);
144   // Initial open parenthesis label
145   Label first_open_paren = max_label + 1;
146   Label first_close_paren = max_label + max_non_term_count + 1;
147 
148   for (StateIterator< Fst<Arc> > siter(*ofst);
149        !siter.Done(); siter.Next()) {
150     StateId os = siter.Value();
151     MIter *aiter = new MIter(ofst, os);
152     for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
153       Arc arc = aiter->Value();
154       typename unordered_map<Label, size_t>::const_iterator lit =
155           label2id.find(arc.olabel);
156       if (lit != label2id.end()) {
157         size_t nfst_id = lit->second;
158 
159         // Get parentheses. Ensures distinct parenthesis pair per
160         // non-terminal and destination state but otherwise reuses them.
161         Label open_paren = kNoLabel, close_paren = kNoLabel;
162         pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
163         typename ParenMap::const_iterator pit = paren_map.find(paren_key);
164         if (pit != paren_map.end()) {
165           size_t paren_id = pit->second;
166           open_paren = (*parens)[paren_id].first;
167           close_paren = (*parens)[paren_id].second;
168         } else {
169           size_t paren_id = nparens[nfst_id]++;
170           open_paren = first_open_paren + paren_id;
171           close_paren = first_close_paren + paren_id;
172           paren_map[paren_key] = paren_id;
173           if (paren_id >= parens->size())
174             parens->push_back(make_pair(open_paren, close_paren));
175         }
176 
177         // Sets open parenthesis.
178         Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
179         aiter->SetValue(sarc);
180 
181         // Adds close parentheses.
182         for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
183           pair<StateId, Weight> &p = fst_final[nfst_id][i];
184           Arc farc(close_paren, close_paren, p.second, arc.nextstate);
185 
186           ofst->AddArc(p.first, farc);
187           if (os == p.first) {  // Invalidated iterator
188             delete aiter;
189             aiter = new MIter(ofst, os);
190             aiter->Seek(n);
191           }
192         }
193       }
194     }
195     delete aiter;
196   }
197 }
198 
199 }  // namespace fst
200 
201 #endif  // FST_EXTENSIONS_PDT_REPLACE_H__
202