1 
2 // string.h
3 
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: allauzen@google.com (Cyril Allauzen)
18 //
19 // \file
// Utilities to convert strings into FSTs and to print string FSTs as strings.
21 //
22 
23 #ifndef FST_LIB_STRING_H_
24 #define FST_LIB_STRING_H_
25 
26 #include <fst/compact-fst.h>
27 #include <fst/icu.h>
28 #include <fst/mutable-fst.h>
29 
30 DECLARE_string(fst_field_separator);
31 
32 namespace fst {
33 
34 // Functor compiling a string in an FST
35 template <class A>
36 class StringCompiler {
37  public:
38   typedef A Arc;
39   typedef typename A::Label Label;
40   typedef typename A::Weight Weight;
41 
42   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
43 
44   StringCompiler(TokenType type, const SymbolTable *syms = 0,
45                  Label unknown_label = kNoLabel,
46                  bool allow_negative = false)
token_type_(type)47       : token_type_(type), syms_(syms), unknown_label_(unknown_label),
48         allow_negative_(allow_negative) {}
49 
50   // Compile string 's' into FST 'fst'.
51   template <class F>
operator()52   bool operator()(const string &s, F *fst) const {
53     vector<Label> labels;
54     if (!ConvertStringToLabels(s, &labels))
55       return false;
56     Compile(labels, fst);
57     return true;
58   }
59 
60   template <class F>
operator()61   bool operator()(const string &s, F *fst, Weight w) const {
62     vector<Label> labels;
63     if (!ConvertStringToLabels(s, &labels))
64       return false;
65     Compile(labels, fst, w);
66     return true;
67   }
68 
69  private:
ConvertStringToLabels(const string & str,vector<Label> * labels)70   bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
71     labels->clear();
72     if (token_type_ == BYTE) {
73       for (size_t i = 0; i < str.size(); ++i)
74         labels->push_back(static_cast<unsigned char>(str[i]));
75     } else if (token_type_ == UTF8) {
76       return UTF8StringToLabels(str, labels);
77     } else {
78       char *c_str = new char[str.size() + 1];
79       str.copy(c_str, str.size());
80       c_str[str.size()] = 0;
81       vector<char *> vec;
82       string separator = "\n" + FLAGS_fst_field_separator;
83       SplitToVector(c_str, separator.c_str(), &vec, true);
84       for (size_t i = 0; i < vec.size(); ++i) {
85         Label label;
86         if (!ConvertSymbolToLabel(vec[i], &label))
87           return false;
88         labels->push_back(label);
89       }
90       delete[] c_str;
91     }
92     return true;
93   }
94 
95   void Compile(const vector<Label> &labels, MutableFst<A> *fst,
96                const Weight &weight = Weight::One()) const {
97     fst->DeleteStates();
98     while (fst->NumStates() <= labels.size())
99       fst->AddState();
100     for (size_t i = 0; i < labels.size(); ++i)
101       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
102     fst->SetStart(0);
103     fst->SetFinal(labels.size(), weight);
104   }
105 
106   template <class Unsigned>
Compile(const vector<Label> & labels,CompactFst<A,StringCompactor<A>,Unsigned> * fst)107   void Compile(const vector<Label> &labels,
108                CompactFst<A, StringCompactor<A>, Unsigned> *fst) const {
109     fst->SetCompactElements(labels.begin(), labels.end());
110   }
111 
112   template <class Unsigned>
113   void Compile(const vector<Label> &labels,
114                CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst,
115                const Weight &weight = Weight::One()) const {
116     vector<pair<Label, Weight> > compacts;
117     compacts.reserve(labels.size());
118     for (size_t i = 0; i < labels.size(); ++i)
119       compacts.push_back(make_pair(labels[i], Weight::One()));
120     compacts.back().second = weight;
121     fst->SetCompactElements(compacts.begin(), compacts.end());
122   }
123 
ConvertSymbolToLabel(const char * s,Label * output)124   bool ConvertSymbolToLabel(const char *s, Label* output) const {
125     int64 n;
126     if (syms_) {
127       n = syms_->Find(s);
128       if ((n == -1) && (unknown_label_ != kNoLabel))
129         n = unknown_label_;
130       if (n == -1 || (!allow_negative_ && n < 0)) {
131         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
132                 << "\" is not mapped to any integer label, symbol table = "
133                  << syms_->Name();
134         return false;
135       }
136     } else {
137       char *p;
138       n = strtoll(s, &p, 10);
139       if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
140         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
141                 << "= \"" << s << "\"";
142         return false;
143       }
144     }
145     *output = n;
146     return true;
147   }
148 
149   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
150   const SymbolTable *syms_;  // Symbol table used when token type is symbol
151   Label unknown_label_;      // Label for token missing from symbol table
152   bool allow_negative_;      // Negative labels allowed?
153 
154   DISALLOW_COPY_AND_ASSIGN(StringCompiler);
155 };
156 
157 // Functor to print a string FST as a string.
158 template <class A>
159 class StringPrinter {
160  public:
161   typedef A Arc;
162   typedef typename A::Label Label;
163   typedef typename A::StateId StateId;
164   typedef typename A::Weight Weight;
165 
166   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
167 
168   StringPrinter(TokenType token_type,
169                 const SymbolTable *syms = 0)
token_type_(token_type)170       : token_type_(token_type), syms_(syms) {}
171 
172   // Convert the FST 'fst' into the string 'output'
operator()173   bool operator()(const Fst<A> &fst, string *output) {
174     bool is_a_string = FstToLabels(fst);
175     if (!is_a_string) {
176       VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
177       return false;
178     }
179 
180     output->clear();
181 
182     if (token_type_ == SYMBOL) {
183       stringstream sstrm;
184       for (size_t i = 0; i < labels_.size(); ++i) {
185         if (i)
186           sstrm << *(FLAGS_fst_field_separator.rbegin());
187         if (!PrintLabel(labels_[i], sstrm))
188           return false;
189       }
190       *output = sstrm.str();
191     } else if (token_type_ == BYTE) {
192       output->reserve(labels_.size());
193       for (size_t i = 0; i < labels_.size(); ++i) {
194         output->push_back(labels_[i]);
195       }
196     } else if (token_type_ == UTF8) {
197       return LabelsToUTF8String(labels_, output);
198     } else {
199       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
200               << token_type_;
201       return false;
202     }
203     return true;
204   }
205 
206  private:
FstToLabels(const Fst<A> & fst)207   bool FstToLabels(const Fst<A> &fst) {
208     labels_.clear();
209 
210     StateId s = fst.Start();
211     if (s == kNoStateId) {
212       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
213               << "string fst.";
214       return false;
215     }
216 
217     while (fst.Final(s) == Weight::Zero()) {
218       ArcIterator<Fst<A> > aiter(fst, s);
219       if (aiter.Done()) {
220         VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
221                 << "not reach final state.";
222         return false;
223       }
224 
225       const A& arc = aiter.Value();
226       labels_.push_back(arc.olabel);
227 
228       s = arc.nextstate;
229       if (s == kNoStateId) {
230         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
231                 << "state.";
232         return false;
233       }
234 
235       aiter.Next();
236       if (!aiter.Done()) {
237         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
238                 << "outgoing arcs found.";
239         return false;
240       }
241     }
242 
243     return true;
244   }
245 
PrintLabel(Label lab,ostream & ostrm)246   bool PrintLabel(Label lab, ostream& ostrm) {
247     if (syms_) {
248       string symbol = syms_->Find(lab);
249       if (symbol == "") {
250         VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
251                 << "mapped to any textual symbol, symbol table = "
252                  << syms_->Name();
253         return false;
254       }
255       ostrm << symbol;
256     } else {
257       ostrm << lab;
258     }
259     return true;
260   }
261 
262   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
263   const SymbolTable *syms_;  // Symbol table used when token type is symbol
264   vector<Label> labels_;     // Input FST labels.
265 
266   DISALLOW_COPY_AND_ASSIGN(StringPrinter);
267 };
268 
269 }  // namespace fst
270 
271 #endif // FST_LIB_STRING_H_
272