1
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // All Rights Reserved.
16 //
17 // Author : Johan Schalkwyk
18 //
19 // \file
20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
21
22 #ifndef FST_LIB_SYMBOL_TABLE_H__
23 #define FST_LIB_SYMBOL_TABLE_H__
24
25 #include <cstring>
26 #include <string>
27 #include <utility>
28 using std::pair; using std::make_pair;
29 #include <vector>
30 using std::vector;
31
32
33 #include <fst/compat.h>
34 #include <iostream>
35 #include <fstream>
36 #include <sstream>
37
38
39 #include <map>
40
41 DECLARE_bool(fst_compat_symbols);
42
43 namespace fst {
44
45 // WARNING: Reading via symbol table read options should
46 // not be used. This is a temporary work around for
47 // reading symbol ranges of previously stored symbol sets.
48 struct SymbolTableReadOptions {
SymbolTableReadOptionsSymbolTableReadOptions49 SymbolTableReadOptions() { }
50
SymbolTableReadOptionsSymbolTableReadOptions51 SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
52 const string& source_)
53 : string_hash_ranges(string_hash_ranges_),
54 source(source_) { }
55
56 vector<pair<int64, int64> > string_hash_ranges;
57 string source;
58 };
59
60 struct SymbolTableTextOptions {
61 SymbolTableTextOptions();
62
63 bool allow_negative;
64 string fst_field_separator;
65 };
66
67 class SymbolTableImpl {
68 public:
SymbolTableImpl(const string & name)69 SymbolTableImpl(const string &name)
70 : name_(name),
71 available_key_(0),
72 dense_key_limit_(0),
73 check_sum_finalized_(false) {}
74
SymbolTableImpl(const SymbolTableImpl & impl)75 explicit SymbolTableImpl(const SymbolTableImpl& impl)
76 : name_(impl.name_),
77 available_key_(0),
78 dense_key_limit_(0),
79 check_sum_finalized_(false) {
80 for (size_t i = 0; i < impl.symbols_.size(); ++i) {
81 AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
82 }
83 }
84
~SymbolTableImpl()85 ~SymbolTableImpl() {
86 for (size_t i = 0; i < symbols_.size(); ++i)
87 delete[] symbols_[i];
88 }
89
90 // TODO(johans): Add flag to specify whether the symbol
91 // should be indexed as string or int or both.
92 int64 AddSymbol(const string& symbol, int64 key);
93
AddSymbol(const string & symbol)94 int64 AddSymbol(const string& symbol) {
95 int64 key = Find(symbol);
96 return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
97 }
98
99 static SymbolTableImpl* ReadText(
100 istream &strm, const string &name,
101 const SymbolTableTextOptions &opts = SymbolTableTextOptions());
102
103 static SymbolTableImpl* Read(istream &strm,
104 const SymbolTableReadOptions& opts);
105
106 bool Write(ostream &strm) const;
107
108 //
109 // Return the string associated with the key. If the key is out of
110 // range (<0, >max), return an empty string.
Find(int64 key)111 string Find(int64 key) const {
112 if (key >=0 && key < dense_key_limit_)
113 return string(symbols_[key]);
114
115 map<int64, const char*>::const_iterator it =
116 key_map_.find(key);
117 if (it == key_map_.end()) {
118 return "";
119 }
120 return string(it->second);
121 }
122
123 //
124 // Return the key associated with the symbol. If the symbol
125 // does not exists, return SymbolTable::kNoSymbol.
Find(const string & symbol)126 int64 Find(const string& symbol) const {
127 return Find(symbol.c_str());
128 }
129
130 //
131 // Return the key associated with the symbol. If the symbol
132 // does not exists, return SymbolTable::kNoSymbol.
Find(const char * symbol)133 int64 Find(const char* symbol) const {
134 map<const char *, int64, StrCmp>::const_iterator it =
135 symbol_map_.find(symbol);
136 if (it == symbol_map_.end()) {
137 return -1;
138 }
139 return it->second;
140 }
141
GetNthKey(ssize_t pos)142 int64 GetNthKey(ssize_t pos) const {
143 if ((pos < 0) || (pos >= symbols_.size())) return -1;
144 else return Find(symbols_[pos]);
145 }
146
Name()147 const string& Name() const { return name_; }
148
IncrRefCount()149 int IncrRefCount() const {
150 return ref_count_.Incr();
151 }
DecrRefCount()152 int DecrRefCount() const {
153 return ref_count_.Decr();
154 }
RefCount()155 int RefCount() const {
156 return ref_count_.count();
157 }
158
CheckSum()159 string CheckSum() const {
160 MaybeRecomputeCheckSum();
161 return check_sum_string_;
162 }
163
LabeledCheckSum()164 string LabeledCheckSum() const {
165 MaybeRecomputeCheckSum();
166 return labeled_check_sum_string_;
167 }
168
AvailableKey()169 int64 AvailableKey() const {
170 return available_key_;
171 }
172
NumSymbols()173 size_t NumSymbols() const {
174 return symbols_.size();
175 }
176
177 private:
178 // Recomputes the checksums (both of them) if we've had changes since the last
179 // computation (i.e., if check_sum_finalized_ is false).
180 // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
181 // if the checksum is up-to-date (requiring no recomputation).
182 void MaybeRecomputeCheckSum() const;
183
184 struct StrCmp {
operatorStrCmp185 bool operator()(const char *s1, const char *s2) const {
186 return strcmp(s1, s2) < 0;
187 }
188 };
189
190 string name_;
191 int64 available_key_;
192 int64 dense_key_limit_;
193 vector<const char *> symbols_;
194 map<int64, const char*> key_map_;
195 map<const char *, int64, StrCmp> symbol_map_;
196
197 mutable RefCounter ref_count_;
198 mutable bool check_sum_finalized_;
199 mutable string check_sum_string_;
200 mutable string labeled_check_sum_string_;
201 mutable Mutex check_sum_mutex_;
202 };
203
204 //
205 // \class SymbolTable
206 // \brief Symbol (string) to int and reverse mapping
207 //
208 // The SymbolTable implements the mappings of labels to strings and reverse.
209 // SymbolTables are used to describe the alphabet of the input and output
210 // labels for arcs in a Finite State Transducer.
211 //
212 // SymbolTables are reference counted and can therefore be shared across
213 // multiple machines. For example a language model grammar G, with a
214 // SymbolTable for the words in the language model can share this symbol
215 // table with the lexical representation L o G.
216 //
217 class SymbolTable {
218 public:
219 static const int64 kNoSymbol = -1;
220
221 // Construct symbol table with an unspecified name.
SymbolTable()222 SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
223
224 // Construct symbol table with a unique name.
SymbolTable(const string & name)225 SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
226
227 // Create a reference counted copy.
SymbolTable(const SymbolTable & table)228 SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
229 impl_->IncrRefCount();
230 }
231
232 // Derefence implentation object. When reference count hits 0, delete
233 // implementation.
~SymbolTable()234 virtual ~SymbolTable() {
235 if (!impl_->DecrRefCount()) delete impl_;
236 }
237
238 // Copys the implemenation from one symbol table to another.
239 void operator=(const SymbolTable &st) {
240 if (impl_ != st.impl_) {
241 st.impl_->IncrRefCount();
242 if (!impl_->DecrRefCount()) delete impl_;
243 impl_ = st.impl_;
244 }
245 }
246
247 // Read an ascii representation of the symbol table from an istream. Pass a
248 // name to give the resulting SymbolTable.
249 static SymbolTable* ReadText(
250 istream &strm, const string& name,
251 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
252 SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
253 if (!impl)
254 return 0;
255 else
256 return new SymbolTable(impl);
257 }
258
259 // read an ascii representation of the symbol table
260 static SymbolTable* ReadText(const string& filename,
261 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
262 ifstream strm(filename.c_str(), ifstream::in);
263 if (!strm) {
264 LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
265 return 0;
266 }
267 return ReadText(strm, filename, opts);
268 }
269
270
271 // WARNING: Reading via symbol table read options should
272 // not be used. This is a temporary work around.
Read(istream & strm,const SymbolTableReadOptions & opts)273 static SymbolTable* Read(istream &strm,
274 const SymbolTableReadOptions& opts) {
275 SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
276 if (!impl)
277 return 0;
278 else
279 return new SymbolTable(impl);
280 }
281
282 // read a binary dump of the symbol table from a stream
Read(istream & strm,const string & source)283 static SymbolTable* Read(istream &strm, const string& source) {
284 SymbolTableReadOptions opts;
285 opts.source = source;
286 return Read(strm, opts);
287 }
288
289 // read a binary dump of the symbol table
Read(const string & filename)290 static SymbolTable* Read(const string& filename) {
291 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
292 if (!strm) {
293 LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
294 return 0;
295 }
296 return Read(strm, filename);
297 }
298
299 //--------------------------------------------------------
300 // Derivable Interface (final)
301 //--------------------------------------------------------
302 // create a reference counted copy
Copy()303 virtual SymbolTable* Copy() const {
304 return new SymbolTable(*this);
305 }
306
307 // Add a symbol with given key to table. A symbol table also
308 // keeps track of the last available key (highest key value in
309 // the symbol table).
AddSymbol(const string & symbol,int64 key)310 virtual int64 AddSymbol(const string& symbol, int64 key) {
311 MutateCheck();
312 return impl_->AddSymbol(symbol, key);
313 }
314
315 // Add a symbol to the table. The associated value key is automatically
316 // assigned by the symbol table.
AddSymbol(const string & symbol)317 virtual int64 AddSymbol(const string& symbol) {
318 MutateCheck();
319 return impl_->AddSymbol(symbol);
320 }
321
322 // Add another symbol table to this table. All key values will be offset
323 // by the current available key (highest key value in the symbol table).
324 // Note string symbols with the same key value with still have the same
325 // key value after the symbol table has been merged, but a different
326 // value. Adding symbol tables do not result in changes in the base table.
327 virtual void AddTable(const SymbolTable& table);
328
329 // return the name of the symbol table
Name()330 virtual const string& Name() const {
331 return impl_->Name();
332 }
333
334 // Return the label-agnostic MD5 check-sum for this table. All new symbols
335 // added to the table will result in an updated checksum.
336 // DEPRECATED.
CheckSum()337 virtual string CheckSum() const {
338 return impl_->CheckSum();
339 }
340
341 // Same as CheckSum(), but this returns an label-dependent version.
LabeledCheckSum()342 virtual string LabeledCheckSum() const {
343 return impl_->LabeledCheckSum();
344 }
345
Write(ostream & strm)346 virtual bool Write(ostream &strm) const {
347 return impl_->Write(strm);
348 }
349
Write(const string & filename)350 bool Write(const string& filename) const {
351 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
352 if (!strm) {
353 LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
354 return false;
355 }
356 return Write(strm);
357 }
358
359 // Dump an ascii text representation of the symbol table via a stream
360 virtual bool WriteText(
361 ostream &strm,
362 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
363
364 // Dump an ascii text representation of the symbol table
WriteText(const string & filename)365 bool WriteText(const string& filename) const {
366 ofstream strm(filename.c_str());
367 if (!strm) {
368 LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
369 return false;
370 }
371 return WriteText(strm);
372 }
373
374 // Return the string associated with the key. If the key is out of
375 // range (<0, >max), log error and return an empty string.
Find(int64 key)376 virtual string Find(int64 key) const {
377 return impl_->Find(key);
378 }
379
380 // Return the key associated with the symbol. If the symbol
381 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const string & symbol)382 virtual int64 Find(const string& symbol) const {
383 return impl_->Find(symbol);
384 }
385
386 // Return the key associated with the symbol. If the symbol
387 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const char * symbol)388 virtual int64 Find(const char* symbol) const {
389 return impl_->Find(symbol);
390 }
391
392 // Return the current available key (i.e highest key number+1) in
393 // the symbol table
AvailableKey(void)394 virtual int64 AvailableKey(void) const {
395 return impl_->AvailableKey();
396 }
397
398 // Return the current number of symbols in table (not necessarily
399 // equal to AvailableKey())
NumSymbols(void)400 virtual size_t NumSymbols(void) const {
401 return impl_->NumSymbols();
402 }
403
GetNthKey(ssize_t pos)404 virtual int64 GetNthKey(ssize_t pos) const {
405 return impl_->GetNthKey(pos);
406 }
407
408 private:
SymbolTable(SymbolTableImpl * impl)409 explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
410
MutateCheck()411 void MutateCheck() {
412 // Copy on write
413 if (impl_->RefCount() > 1) {
414 impl_->DecrRefCount();
415 impl_ = new SymbolTableImpl(*impl_);
416 }
417 }
418
Impl()419 const SymbolTableImpl* Impl() const {
420 return impl_;
421 }
422
423 private:
424 SymbolTableImpl* impl_;
425 };
426
427
428 //
429 // \class SymbolTableIterator
430 // \brief Iterator class for symbols in a symbol table
431 class SymbolTableIterator {
432 public:
SymbolTableIterator(const SymbolTable & table)433 SymbolTableIterator(const SymbolTable& table)
434 : table_(table),
435 pos_(0),
436 nsymbols_(table.NumSymbols()),
437 key_(table.GetNthKey(0)) { }
438
~SymbolTableIterator()439 ~SymbolTableIterator() { }
440
441 // is iterator done
Done(void)442 bool Done(void) {
443 return (pos_ == nsymbols_);
444 }
445
446 // return the Value() of the current symbol (int64 key)
Value(void)447 int64 Value(void) {
448 return key_;
449 }
450
451 // return the string of the current symbol
Symbol(void)452 string Symbol(void) {
453 return table_.Find(key_);
454 }
455
456 // advance iterator forward
Next(void)457 void Next(void) {
458 ++pos_;
459 if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
460 }
461
462 // reset iterator
Reset(void)463 void Reset(void) {
464 pos_ = 0;
465 key_ = table_.GetNthKey(0);
466 }
467
468 private:
469 const SymbolTable& table_;
470 ssize_t pos_;
471 size_t nsymbols_;
472 int64 key_;
473 };
474
475
476 // Tests compatibilty between two sets of symbol tables
477 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
478 bool warning = true) {
479 if (!FLAGS_fst_compat_symbols) {
480 return true;
481 } else if (!syms1 && !syms2) {
482 return true;
483 } else if (syms1 && !syms2) {
484 if (warning)
485 LOG(WARNING) <<
486 "CompatSymbols: first symbol table present but second missing";
487 return false;
488 } else if (!syms1 && syms2) {
489 if (warning)
490 LOG(WARNING) <<
491 "CompatSymbols: second symbol table present but first missing";
492 return false;
493 } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
494 if (warning)
495 LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
496 return false;
497 } else {
498 return true;
499 }
500 }
501
502
503 // Relabels a symbol table as specified by the input vector of pairs
504 // (old label, new label). The new symbol table only retains symbols
505 // for which a relabeling is *explicitely* specified.
506 // TODO(allauzen): consider adding options to allow for some form
507 // of implicit identity relabeling.
508 template <class Label>
RelabelSymbolTable(const SymbolTable * table,const vector<pair<Label,Label>> & pairs)509 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
510 const vector<pair<Label, Label> > &pairs) {
511 SymbolTable *new_table = new SymbolTable(
512 table->Name().empty() ? string() :
513 (string("relabeled_") + table->Name()));
514
515 for (size_t i = 0; i < pairs.size(); ++i)
516 new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
517
518 return new_table;
519 }
520
521 // Symbol Table Serialization
SymbolTableToString(const SymbolTable * table,string * result)522 inline void SymbolTableToString(const SymbolTable *table, string *result) {
523 ostringstream ostrm;
524 table->Write(ostrm);
525 *result = ostrm.str();
526 }
527
StringToSymbolTable(const string & s)528 inline SymbolTable *StringToSymbolTable(const string &s) {
529 istringstream istrm(s);
530 return SymbolTable::Read(istrm, SymbolTableReadOptions());
531 }
532
533
534
535 } // namespace fst
536
537 #endif // FST_LIB_SYMBOL_TABLE_H__
538