1 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// All Rights Reserved.
//
// Author : Johan Schalkwyk
//
// \file
// Classes to provide symbol-to-integer and integer-to-symbol mappings.
21 
22 #include <fst/symbol-table.h>
23 
24 #include <fst/util.h>
25 
26 DEFINE_bool(fst_compat_symbols, true,
27             "Require symbol tables to match when appropriate");
28 DEFINE_string(fst_field_separator, "\t ",
29               "Set of characters used as a separator between printed fields");
30 
31 namespace fst {
32 
// Size, in bytes, of the buffer used to read one line of a textual symbols
// file; this bounds the maximum line length that can be read.
const int kLineLen = 8096;

// Magic number written at the start of a binary symbol table. It identifies
// the stream data as a symbol table (and its endianness).
static const int32 kSymbolTableMagicNumber = 2125658996;
38 
SymbolTableTextOptions()39 SymbolTableTextOptions::SymbolTableTextOptions()
40     : allow_negative(false), fst_field_separator(FLAGS_fst_field_separator) { }
41 
ReadText(istream & strm,const string & filename,const SymbolTableTextOptions & opts)42 SymbolTableImpl* SymbolTableImpl::ReadText(istream &strm,
43                                            const string &filename,
44                                            const SymbolTableTextOptions &opts) {
45   SymbolTableImpl* impl = new SymbolTableImpl(filename);
46 
47   int64 nline = 0;
48   char line[kLineLen];
49   while (strm.getline(line, kLineLen)) {
50     ++nline;
51     vector<char *> col;
52     string separator = opts.fst_field_separator + "\n";
53     SplitToVector(line, separator.c_str(), &col, true);
54     if (col.size() == 0)  // empty line
55       continue;
56     if (col.size() != 2) {
57       LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
58                  << col.size() << "), "
59                  << "file = " << filename << ", line = " << nline
60                  << ":<" << line << ">";
61       delete impl;
62       return 0;
63     }
64     const char *symbol = col[0];
65     const char *value = col[1];
66     char *p;
67     int64 key = strtoll(value, &p, 10);
68     if (p < value + strlen(value) ||
69         (!opts.allow_negative && key < 0) || key == -1) {
70       LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
71                  << value << "\", "
72                  << "file = " << filename << ", line = " << nline;
73       delete impl;
74       return 0;
75     }
76     impl->AddSymbol(symbol, key);
77   }
78 
79   return impl;
80 }
81 
MaybeRecomputeCheckSum() const82 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
83   {
84     ReaderMutexLock check_sum_lock(&check_sum_mutex_);
85     if (check_sum_finalized_)
86       return;
87   }
88 
89   // We'll aquire an exclusive lock to recompute the checksums.
90   MutexLock check_sum_lock(&check_sum_mutex_);
91   if (check_sum_finalized_)  // Another thread (coming in around the same time
92     return;                  // might have done it already).  So we recheck.
93 
94   // Calculate the original label-agnostic check sum.
95   CheckSummer check_sum;
96   for (int64 i = 0; i < symbols_.size(); ++i)
97     check_sum.Update(symbols_[i], strlen(symbols_[i]) + 1);
98   check_sum_string_ = check_sum.Digest();
99 
100   // Calculate the safer, label-dependent check sum.
101   CheckSummer labeled_check_sum;
102   for (int64 key = 0; key < dense_key_limit_; ++key) {
103     ostringstream line;
104     line << symbols_[key] << '\t' << key;
105     labeled_check_sum.Update(line.str().data(), line.str().size());
106   }
107   for (map<int64, const char*>::const_iterator it =
108        key_map_.begin();
109        it != key_map_.end();
110        ++it) {
111     if (it->first >= dense_key_limit_) {
112       ostringstream line;
113       line << it->second << '\t' << it->first;
114       labeled_check_sum.Update(line.str().data(), line.str().size());
115     }
116   }
117   labeled_check_sum_string_ = labeled_check_sum.Digest();
118 
119   check_sum_finalized_ = true;
120 }
121 
AddSymbol(const string & symbol,int64 key)122 int64 SymbolTableImpl::AddSymbol(const string& symbol, int64 key) {
123   map<const char *, int64, StrCmp>::const_iterator it =
124       symbol_map_.find(symbol.c_str());
125   if (it == symbol_map_.end()) {  // only add if not in table
126     check_sum_finalized_ = false;
127 
128     char *csymbol = new char[symbol.size() + 1];
129     strcpy(csymbol, symbol.c_str());
130     symbols_.push_back(csymbol);
131     key_map_[key] = csymbol;
132     symbol_map_[csymbol] = key;
133 
134     if (key >= available_key_) {
135       available_key_ = key + 1;
136     }
137   } else {
138     // Log if symbol already in table with different key
139     if (it->second != key) {
140       VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
141               << " already in symbol_map_ with key = "
142               << it->second
143               << " but supplied new key = " << key
144               << " (ignoring new key)";
145     }
146   }
147   return key;
148 }
149 
// Returns true when `key` lies inside at least one of the inclusive
// [first, second] intervals of `ranges`. An empty `ranges` imposes no
// restriction, so every key is then reported as in range.
static bool IsInRange(const vector<pair<int64, int64> >& ranges,
                      int64 key) {
  if (ranges.empty()) return true;
  typedef vector<pair<int64, int64> >::const_iterator RangeIter;
  for (RangeIter range = ranges.begin(); range != ranges.end(); ++range) {
    const bool below = key < range->first;
    const bool above = key > range->second;
    if (!below && !above) return true;
  }
  return false;
}
159 
Read(istream & strm,const SymbolTableReadOptions & opts)160 SymbolTableImpl* SymbolTableImpl::Read(istream &strm,
161                                        const SymbolTableReadOptions& opts) {
162   int32 magic_number = 0;
163   ReadType(strm, &magic_number);
164   if (!strm) {
165     LOG(ERROR) << "SymbolTable::Read: read failed";
166     return 0;
167   }
168   string name;
169   ReadType(strm, &name);
170   SymbolTableImpl* impl = new SymbolTableImpl(name);
171   ReadType(strm, &impl->available_key_);
172   int64 size;
173   ReadType(strm, &size);
174   if (!strm) {
175     LOG(ERROR) << "SymbolTable::Read: read failed";
176     delete impl;
177     return 0;
178   }
179 
180   string symbol;
181   int64 key;
182   impl->check_sum_finalized_ = false;
183   for (size_t i = 0; i < size; ++i) {
184     ReadType(strm, &symbol);
185     ReadType(strm, &key);
186     if (!strm) {
187       LOG(ERROR) << "SymbolTable::Read: read failed";
188       delete impl;
189       return 0;
190     }
191 
192     char *csymbol = new char[symbol.size() + 1];
193     strcpy(csymbol, symbol.c_str());
194     impl->symbols_.push_back(csymbol);
195     if (key == impl->dense_key_limit_ &&
196         key == impl->symbols_.size() - 1)
197       impl->dense_key_limit_ = impl->symbols_.size();
198     else
199       impl->key_map_[key] = csymbol;
200 
201     if (IsInRange(opts.string_hash_ranges, key)) {
202       impl->symbol_map_[csymbol] = key;
203     }
204   }
205   return impl;
206 }
207 
Write(ostream & strm) const208 bool SymbolTableImpl::Write(ostream &strm) const {
209   WriteType(strm, kSymbolTableMagicNumber);
210   WriteType(strm, name_);
211   WriteType(strm, available_key_);
212   int64 size = symbols_.size();
213   WriteType(strm, size);
214   // first write out dense keys
215   int64 i = 0;
216   for (; i < dense_key_limit_; ++i) {
217     WriteType(strm, string(symbols_[i]));
218     WriteType(strm, i);
219   }
220   // next write out the remaining non densely packed keys
221   for (map<const char *, int64, StrCmp>::const_iterator it =
222            symbol_map_.begin(); it != symbol_map_.end(); ++it) {
223     if ((it->second >= 0) && (it->second < dense_key_limit_))
224       continue;
225     WriteType(strm, string(it->first));
226     WriteType(strm, it->second);
227     ++i;
228   }
229   if (i != size) {
230     LOG(ERROR) << "SymbolTable::Write:  write failed";
231     return false;
232   }
233   strm.flush();
234   if (!strm) {
235     LOG(ERROR) << "SymbolTable::Write: write failed";
236     return false;
237   }
238   return true;
239 }
240 
241 const int64 SymbolTable::kNoSymbol;
242 
243 
AddTable(const SymbolTable & table)244 void SymbolTable::AddTable(const SymbolTable& table) {
245   for (SymbolTableIterator iter(table); !iter.Done(); iter.Next())
246     impl_->AddSymbol(iter.Symbol());
247 }
248 
WriteText(ostream & strm,const SymbolTableTextOptions & opts) const249 bool SymbolTable::WriteText(ostream &strm,
250                             const SymbolTableTextOptions &opts) const {
251   if (opts.fst_field_separator.empty()) {
252     LOG(ERROR) << "Missing required field separator";
253     return false;
254   }
255   bool once_only = false;
256   for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
257     ostringstream line;
258     if (iter.Value() < 0 && !opts.allow_negative && !once_only) {
259       LOG(WARNING) << "Negative symbol table entry when not allowed";
260       once_only = true;
261     }
262     line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
263          << '\n';
264     strm.write(line.str().data(), line.str().length());
265   }
266   return true;
267 }
268 }  // namespace fst
269