1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // All Rights Reserved.
16 //
17 // Author : Johan Schalkwyk
18 //
19 // \file
20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
21 
22 #ifndef FST_LIB_SYMBOL_TABLE_H__
23 #define FST_LIB_SYMBOL_TABLE_H__
24 
25 #include <cstring>
26 #include <string>
27 #include <utility>
28 using std::pair; using std::make_pair;
29 #include <vector>
30 using std::vector;
31 
32 
33 #include <fst/compat.h>
34 #include <iostream>
35 #include <fstream>
36 #include <sstream>
37 
38 
39 #include <map>
40 
41 DECLARE_bool(fst_compat_symbols);
42 
43 namespace fst {
44 
45 // WARNING: Reading via symbol table read options should
46 //          not be used. This is a temporary work around for
47 //          reading symbol ranges of previously stored symbol sets.
struct SymbolTableReadOptions {
  // Creates options with no hash ranges and an empty source name.
  SymbolTableReadOptions() { }

  // Creates options holding copies of the given ranges and source name.
  // Both arguments are taken by const reference so each is copied exactly
  // once, into the corresponding member.
  SymbolTableReadOptions(
      const vector<pair<int64, int64> > &string_hash_ranges_,
      const string& source_)
      : string_hash_ranges(string_hash_ranges_),
        source(source_) { }

  // Pairs of int64 values describing string hash ranges.
  // NOTE(review): presumably (begin, end) ranges of previously stored symbol
  // sets, per the warning above this struct -- confirm against the reader.
  vector<pair<int64, int64> > string_hash_ranges;

  // Name of the input being read. SymbolTable::Read(filename) passes the
  // file name here; presumably used for diagnostics -- confirm in Read().
  string source;
};
59 
60 struct SymbolTableTextOptions {
61   SymbolTableTextOptions();
62 
63   bool allow_negative;
64   string fst_field_separator;
65 };
66 
67 class SymbolTableImpl {
68  public:
SymbolTableImpl(const string & name)69   SymbolTableImpl(const string &name)
70       : name_(name),
71         available_key_(0),
72         dense_key_limit_(0),
73         check_sum_finalized_(false) {}
74 
SymbolTableImpl(const SymbolTableImpl & impl)75   explicit SymbolTableImpl(const SymbolTableImpl& impl)
76       : name_(impl.name_),
77         available_key_(0),
78         dense_key_limit_(0),
79         check_sum_finalized_(false) {
80     for (size_t i = 0; i < impl.symbols_.size(); ++i) {
81       AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
82     }
83   }
84 
~SymbolTableImpl()85   ~SymbolTableImpl() {
86     for (size_t i = 0; i < symbols_.size(); ++i)
87       delete[] symbols_[i];
88   }
89 
90   // TODO(johans): Add flag to specify whether the symbol
91   //               should be indexed as string or int or both.
92   int64 AddSymbol(const string& symbol, int64 key);
93 
AddSymbol(const string & symbol)94   int64 AddSymbol(const string& symbol) {
95     int64 key = Find(symbol);
96     return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
97   }
98 
99   static SymbolTableImpl* ReadText(
100       istream &strm, const string &name,
101       const SymbolTableTextOptions &opts = SymbolTableTextOptions());
102 
103   static SymbolTableImpl* Read(istream &strm,
104                                const SymbolTableReadOptions& opts);
105 
106   bool Write(ostream &strm) const;
107 
108   //
109   // Return the string associated with the key. If the key is out of
110   // range (<0, >max), return an empty string.
Find(int64 key)111   string Find(int64 key) const {
112     if (key >=0 && key < dense_key_limit_)
113       return string(symbols_[key]);
114 
115     map<int64, const char*>::const_iterator it =
116         key_map_.find(key);
117     if (it == key_map_.end()) {
118       return "";
119     }
120     return string(it->second);
121   }
122 
123   //
124   // Return the key associated with the symbol. If the symbol
125   // does not exists, return SymbolTable::kNoSymbol.
Find(const string & symbol)126   int64 Find(const string& symbol) const {
127     return Find(symbol.c_str());
128   }
129 
130   //
131   // Return the key associated with the symbol. If the symbol
132   // does not exists, return SymbolTable::kNoSymbol.
Find(const char * symbol)133   int64 Find(const char* symbol) const {
134     map<const char *, int64, StrCmp>::const_iterator it =
135         symbol_map_.find(symbol);
136     if (it == symbol_map_.end()) {
137       return -1;
138     }
139     return it->second;
140   }
141 
GetNthKey(ssize_t pos)142   int64 GetNthKey(ssize_t pos) const {
143     if ((pos < 0) || (pos >= symbols_.size())) return -1;
144     else return Find(symbols_[pos]);
145   }
146 
Name()147   const string& Name() const { return name_; }
148 
IncrRefCount()149   int IncrRefCount() const {
150     return ref_count_.Incr();
151   }
DecrRefCount()152   int DecrRefCount() const {
153     return ref_count_.Decr();
154   }
RefCount()155   int RefCount() const {
156     return ref_count_.count();
157   }
158 
CheckSum()159   string CheckSum() const {
160     MaybeRecomputeCheckSum();
161     return check_sum_string_;
162   }
163 
LabeledCheckSum()164   string LabeledCheckSum() const {
165     MaybeRecomputeCheckSum();
166     return labeled_check_sum_string_;
167   }
168 
AvailableKey()169   int64 AvailableKey() const {
170     return available_key_;
171   }
172 
NumSymbols()173   size_t NumSymbols() const {
174     return symbols_.size();
175   }
176 
177  private:
178   // Recomputes the checksums (both of them) if we've had changes since the last
179   // computation (i.e., if check_sum_finalized_ is false).
180   // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
181   // if the checksum is up-to-date (requiring no recomputation).
182   void MaybeRecomputeCheckSum() const;
183 
184   struct StrCmp {
operatorStrCmp185     bool operator()(const char *s1, const char *s2) const {
186       return strcmp(s1, s2) < 0;
187     }
188   };
189 
190   string name_;
191   int64 available_key_;
192   int64 dense_key_limit_;
193   vector<const char *> symbols_;
194   map<int64, const char*> key_map_;
195   map<const char *, int64, StrCmp> symbol_map_;
196 
197   mutable RefCounter ref_count_;
198   mutable bool check_sum_finalized_;
199   mutable string check_sum_string_;
200   mutable string labeled_check_sum_string_;
201   mutable Mutex check_sum_mutex_;
202 };
203 
204 //
205 // \class SymbolTable
206 // \brief Symbol (string) to int and reverse mapping
207 //
208 // The SymbolTable implements the mappings of labels to strings and reverse.
209 // SymbolTables are used to describe the alphabet of the input and output
210 // labels for arcs in a Finite State Transducer.
211 //
212 // SymbolTables are reference counted and can therefore be shared across
213 // multiple machines. For example a language model grammar G, with a
214 // SymbolTable for the words in the language model can share this symbol
215 // table with the lexical representation L o G.
216 //
217 class SymbolTable {
218  public:
219   static const int64 kNoSymbol = -1;
220 
221   // Construct symbol table with an unspecified name.
SymbolTable()222   SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
223 
224   // Construct symbol table with a unique name.
SymbolTable(const string & name)225   SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
226 
227   // Create a reference counted copy.
SymbolTable(const SymbolTable & table)228   SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
229     impl_->IncrRefCount();
230   }
231 
232   // Derefence implentation object. When reference count hits 0, delete
233   // implementation.
~SymbolTable()234   virtual ~SymbolTable() {
235     if (!impl_->DecrRefCount()) delete impl_;
236   }
237 
238   // Copys the implemenation from one symbol table to another.
239   void operator=(const SymbolTable &st) {
240     if (impl_ != st.impl_) {
241       st.impl_->IncrRefCount();
242       if (!impl_->DecrRefCount()) delete impl_;
243       impl_ = st.impl_;
244     }
245   }
246 
247   // Read an ascii representation of the symbol table from an istream. Pass a
248   // name to give the resulting SymbolTable.
249   static SymbolTable* ReadText(
250       istream &strm, const string& name,
251       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
252     SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
253     if (!impl)
254       return 0;
255     else
256       return new SymbolTable(impl);
257   }
258 
259   // read an ascii representation of the symbol table
260   static SymbolTable* ReadText(const string& filename,
261       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
262     ifstream strm(filename.c_str(), ifstream::in);
263     if (!strm) {
264       LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
265       return 0;
266     }
267     return ReadText(strm, filename, opts);
268   }
269 
270 
271   // WARNING: Reading via symbol table read options should
272   //          not be used. This is a temporary work around.
Read(istream & strm,const SymbolTableReadOptions & opts)273   static SymbolTable* Read(istream &strm,
274                            const SymbolTableReadOptions& opts) {
275     SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
276     if (!impl)
277       return 0;
278     else
279       return new SymbolTable(impl);
280   }
281 
282   // read a binary dump of the symbol table from a stream
Read(istream & strm,const string & source)283   static SymbolTable* Read(istream &strm, const string& source) {
284     SymbolTableReadOptions opts;
285     opts.source = source;
286     return Read(strm, opts);
287   }
288 
289   // read a binary dump of the symbol table
Read(const string & filename)290   static SymbolTable* Read(const string& filename) {
291     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
292     if (!strm) {
293       LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
294       return 0;
295     }
296     return Read(strm, filename);
297   }
298 
299   //--------------------------------------------------------
300   // Derivable Interface (final)
301   //--------------------------------------------------------
302   // create a reference counted copy
Copy()303   virtual SymbolTable* Copy() const {
304     return new SymbolTable(*this);
305   }
306 
307   // Add a symbol with given key to table. A symbol table also
308   // keeps track of the last available key (highest key value in
309   // the symbol table).
AddSymbol(const string & symbol,int64 key)310   virtual int64 AddSymbol(const string& symbol, int64 key) {
311     MutateCheck();
312     return impl_->AddSymbol(symbol, key);
313   }
314 
315   // Add a symbol to the table. The associated value key is automatically
316   // assigned by the symbol table.
AddSymbol(const string & symbol)317   virtual int64 AddSymbol(const string& symbol) {
318     MutateCheck();
319     return impl_->AddSymbol(symbol);
320   }
321 
322   // Add another symbol table to this table. All key values will be offset
323   // by the current available key (highest key value in the symbol table).
324   // Note string symbols with the same key value with still have the same
325   // key value after the symbol table has been merged, but a different
326   // value. Adding symbol tables do not result in changes in the base table.
327   virtual void AddTable(const SymbolTable& table);
328 
329   // return the name of the symbol table
Name()330   virtual const string& Name() const {
331     return impl_->Name();
332   }
333 
334   // Return the label-agnostic MD5 check-sum for this table.  All new symbols
335   // added to the table will result in an updated checksum.
336   // DEPRECATED.
CheckSum()337   virtual string CheckSum() const {
338     return impl_->CheckSum();
339   }
340 
341   // Same as CheckSum(), but this returns an label-dependent version.
LabeledCheckSum()342   virtual string LabeledCheckSum() const {
343     return impl_->LabeledCheckSum();
344   }
345 
Write(ostream & strm)346   virtual bool Write(ostream &strm) const {
347     return impl_->Write(strm);
348   }
349 
Write(const string & filename)350   bool Write(const string& filename) const {
351     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
352     if (!strm) {
353       LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
354       return false;
355     }
356     return Write(strm);
357   }
358 
359   // Dump an ascii text representation of the symbol table via a stream
360   virtual bool WriteText(
361       ostream &strm,
362       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
363 
364   // Dump an ascii text representation of the symbol table
WriteText(const string & filename)365   bool WriteText(const string& filename) const {
366     ofstream strm(filename.c_str());
367     if (!strm) {
368       LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
369       return false;
370     }
371     return WriteText(strm);
372   }
373 
374   // Return the string associated with the key. If the key is out of
375   // range (<0, >max), log error and return an empty string.
Find(int64 key)376   virtual string Find(int64 key) const {
377     return impl_->Find(key);
378   }
379 
380   // Return the key associated with the symbol. If the symbol
381   // does not exists, log error and  return SymbolTable::kNoSymbol
Find(const string & symbol)382   virtual int64 Find(const string& symbol) const {
383     return impl_->Find(symbol);
384   }
385 
386   // Return the key associated with the symbol. If the symbol
387   // does not exists, log error and  return SymbolTable::kNoSymbol
Find(const char * symbol)388   virtual int64 Find(const char* symbol) const {
389     return impl_->Find(symbol);
390   }
391 
392   // Return the current available key (i.e highest key number+1) in
393   // the symbol table
AvailableKey(void)394   virtual int64 AvailableKey(void) const {
395     return impl_->AvailableKey();
396   }
397 
398   // Return the current number of symbols in table (not necessarily
399   // equal to AvailableKey())
NumSymbols(void)400   virtual size_t NumSymbols(void) const {
401     return impl_->NumSymbols();
402   }
403 
GetNthKey(ssize_t pos)404   virtual int64 GetNthKey(ssize_t pos) const {
405     return impl_->GetNthKey(pos);
406   }
407 
408  private:
SymbolTable(SymbolTableImpl * impl)409   explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
410 
MutateCheck()411   void MutateCheck() {
412     // Copy on write
413     if (impl_->RefCount() > 1) {
414       impl_->DecrRefCount();
415       impl_ = new SymbolTableImpl(*impl_);
416     }
417   }
418 
Impl()419   const SymbolTableImpl* Impl() const {
420     return impl_;
421   }
422 
423  private:
424   SymbolTableImpl* impl_;
425 };
426 
427 
428 //
429 // \class SymbolTableIterator
430 // \brief Iterator class for symbols in a symbol table
431 class SymbolTableIterator {
432  public:
SymbolTableIterator(const SymbolTable & table)433   SymbolTableIterator(const SymbolTable& table)
434       : table_(table),
435         pos_(0),
436         nsymbols_(table.NumSymbols()),
437         key_(table.GetNthKey(0)) { }
438 
~SymbolTableIterator()439   ~SymbolTableIterator() { }
440 
441   // is iterator done
Done(void)442   bool Done(void) {
443     return (pos_ == nsymbols_);
444   }
445 
446   // return the Value() of the current symbol (int64 key)
Value(void)447   int64 Value(void) {
448     return key_;
449   }
450 
451   // return the string of the current symbol
Symbol(void)452   string Symbol(void) {
453     return table_.Find(key_);
454   }
455 
456   // advance iterator forward
Next(void)457   void Next(void) {
458     ++pos_;
459     if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
460   }
461 
462   // reset iterator
Reset(void)463   void Reset(void) {
464     pos_ = 0;
465     key_ = table_.GetNthKey(0);
466   }
467 
468  private:
469   const SymbolTable& table_;
470   ssize_t pos_;
471   size_t nsymbols_;
472   int64 key_;
473 };
474 
475 
476 // Tests compatibilty between two sets of symbol tables
477 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
478                           bool warning = true) {
479   if (!FLAGS_fst_compat_symbols) {
480     return true;
481   } else if (!syms1 && !syms2) {
482     return true;
483   } else if (syms1 && !syms2) {
484     if (warning)
485       LOG(WARNING) <<
486           "CompatSymbols: first symbol table present but second missing";
487     return false;
488   } else if (!syms1 && syms2) {
489     if (warning)
490       LOG(WARNING) <<
491           "CompatSymbols: second symbol table present but first missing";
492     return false;
493   } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
494     if (warning)
495       LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
496     return false;
497   } else {
498     return true;
499   }
500 }
501 
502 
503 // Relabels a symbol table as specified by the input vector of pairs
504 // (old label, new label). The new symbol table only retains symbols
505 // for which a relabeling is *explicitely* specified.
506 // TODO(allauzen): consider adding options to allow for some form
507 // of implicit identity relabeling.
508 template <class Label>
RelabelSymbolTable(const SymbolTable * table,const vector<pair<Label,Label>> & pairs)509 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
510                                 const vector<pair<Label, Label> > &pairs) {
511   SymbolTable *new_table = new SymbolTable(
512       table->Name().empty() ? string() :
513       (string("relabeled_") + table->Name()));
514 
515   for (size_t i = 0; i < pairs.size(); ++i)
516     new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
517 
518   return new_table;
519 }
520 
521 // Symbol Table Serialization
SymbolTableToString(const SymbolTable * table,string * result)522 inline void SymbolTableToString(const SymbolTable *table, string *result) {
523   ostringstream ostrm;
524   table->Write(ostrm);
525   *result = ostrm.str();
526 }
527 
StringToSymbolTable(const string & s)528 inline SymbolTable *StringToSymbolTable(const string &s) {
529   istringstream istrm(s);
530   return SymbolTable::Read(istrm, SymbolTableReadOptions());
531 }
532 
533 
534 
535 }  // namespace fst
536 
537 #endif  // FST_LIB_SYMBOL_TABLE_H__
538