1 #include <algorithm>
2 #include <functional>
3 #include <queue>
4 #include <stdexcept>
5 
6 #include "range.h"
7 #include "trie.h"
8 
9 namespace marisa {
10 
build(const char * const * keys,std::size_t num_keys,const std::size_t * key_lengths,const double * key_weights,UInt32 * key_ids,int flags)11 void Trie::build(const char * const *keys, std::size_t num_keys,
12     const std::size_t *key_lengths, const double *key_weights,
13     UInt32 *key_ids, int flags) {
14   MARISA_THROW_IF((keys == NULL) && (num_keys != 0), MARISA_PARAM_ERROR);
15   Vector<Key<String> > temp_keys;
16   temp_keys.resize(num_keys);
17   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
18     MARISA_THROW_IF(keys[i] == NULL, MARISA_PARAM_ERROR);
19     std::size_t length = 0;
20     if (key_lengths == NULL) {
21       while (keys[i][length] != '\0') {
22         ++length;
23       }
24     } else {
25       length = key_lengths[i];
26     }
27     MARISA_THROW_IF(length > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR);
28     temp_keys[i].set_str(String(keys[i], length));
29     temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0);
30   }
31   build_trie(temp_keys, key_ids, flags);
32 }
33 
build(const std::vector<std::string> & keys,std::vector<UInt32> * key_ids,int flags)34 void Trie::build(const std::vector<std::string> &keys,
35     std::vector<UInt32> *key_ids, int flags) {
36   Vector<Key<String> > temp_keys;
37   temp_keys.resize(keys.size());
38   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
39     MARISA_THROW_IF(keys[i].length() > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR);
40     temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length()));
41     temp_keys[i].set_weight(1.0);
42   }
43   build_trie(temp_keys, key_ids, flags);
44 }
45 
build(const std::vector<std::pair<std::string,double>> & keys,std::vector<UInt32> * key_ids,int flags)46 void Trie::build(const std::vector<std::pair<std::string, double> > &keys,
47     std::vector<UInt32> *key_ids, int flags) {
48   Vector<Key<String> > temp_keys;
49   temp_keys.resize(keys.size());
50   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
51     MARISA_THROW_IF(keys[i].first.length() > MARISA_MAX_LENGTH,
52         MARISA_SIZE_ERROR);
53     temp_keys[i].set_str(String(
54         keys[i].first.c_str(), keys[i].first.length()));
55     temp_keys[i].set_weight(keys[i].second);
56   }
57   build_trie(temp_keys, key_ids, flags);
58 }
59 
build_trie(Vector<Key<String>> & keys,std::vector<UInt32> * key_ids,int flags)60 void Trie::build_trie(Vector<Key<String> > &keys,
61     std::vector<UInt32> *key_ids, int flags) {
62   if (key_ids == NULL) {
63     build_trie(keys, static_cast<UInt32 *>(NULL), flags);
64     return;
65   }
66   std::vector<UInt32> temp_key_ids(keys.size());
67   build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags);
68   key_ids->swap(temp_key_ids);
69 }
70 
build_trie(Vector<Key<String>> & keys,UInt32 * key_ids,int flags)71 void Trie::build_trie(Vector<Key<String> > &keys,
72     UInt32 *key_ids, int flags) {
73   Trie temp;
74   Vector<UInt32> terminals;
75   Progress progress(flags);
76   MARISA_THROW_IF(!progress.is_valid(), MARISA_PARAM_ERROR);
77   temp.build_trie(keys, &terminals, progress);
78 
79   typedef std::pair<UInt32, UInt32> TerminalIdPair;
80   Vector<TerminalIdPair> pairs;
81   pairs.resize(terminals.size());
82   for (UInt32 i = 0; i < pairs.size(); ++i) {
83     pairs[i].first = terminals[i];
84     pairs[i].second = i;
85   }
86   terminals.clear();
87   std::sort(pairs.begin(), pairs.end());
88 
89   UInt32 node = 0;
90   for (UInt32 i = 0; i < pairs.size(); ++i) {
91     while (node < pairs[i].first) {
92       temp.terminal_flags_.push_back(false);
93       ++node;
94     }
95     if (node == pairs[i].first) {
96       temp.terminal_flags_.push_back(true);
97       ++node;
98     }
99   }
100   while (node < temp.labels_.size()) {
101     temp.terminal_flags_.push_back(false);
102     ++node;
103   }
104   terminal_flags_.push_back(false);
105   temp.terminal_flags_.build();
106   temp.terminal_flags_.clear_select0s();
107   progress.test_total_size(temp.terminal_flags_.total_size());
108 
109   if (key_ids != NULL) {
110     for (UInt32 i = 0; i < pairs.size(); ++i) {
111       key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first);
112     }
113   }
114   MARISA_THROW_IF(progress.total_size() != temp.total_size(),
115       MARISA_UNEXPECTED_ERROR);
116   temp.swap(this);
117 }
118 
119 template <typename T>
build_trie(Vector<Key<T>> & keys,Vector<UInt32> * terminals,Progress & progress)120 void Trie::build_trie(Vector<Key<T> > &keys,
121     Vector<UInt32> *terminals, Progress &progress) {
122   build_cur(keys, terminals, progress);
123   progress.test_total_size(louds_.total_size());
124   progress.test_total_size(sizeof(num_first_branches_));
125   progress.test_total_size(sizeof(num_keys_));
126   if (link_flags_.empty()) {
127     labels_.shrink();
128     progress.test_total_size(labels_.total_size());
129     progress.test_total_size(link_flags_.total_size());
130     progress.test_total_size(links_.total_size());
131     progress.test_total_size(tail_.total_size());
132     return;
133   }
134 
135   Vector<UInt32> next_terminals;
136   build_next(keys, &next_terminals, progress);
137 
138   if (has_trie()) {
139     progress.test_total_size(trie_->terminal_flags_.total_size());
140   } else if (tail_.mode() == MARISA_BINARY_TAIL) {
141     labels_.push_back('\0');
142     link_flags_.push_back(true);
143   }
144   link_flags_.build();
145 
146   for (UInt32 i = 0; i < next_terminals.size(); ++i) {
147     labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256);
148     next_terminals[i] /= 256;
149   }
150   link_flags_.clear_select0s();
151   if (has_trie() || (tail_.mode() == MARISA_TEXT_TAIL)) {
152     link_flags_.clear_select1s();
153   }
154 
155   links_.build(next_terminals);
156   labels_.shrink();
157   progress.test_total_size(labels_.total_size());
158   progress.test_total_size(link_flags_.total_size());
159   progress.test_total_size(links_.total_size());
160   progress.test_total_size(tail_.total_size());
161 }
162 
// Lays out the current trie level from `keys` using a breadth-first walk.
//
// For every child created, this appends a 1-bit to louds_, one byte to
// labels_ and one bit to link_flags_. The link flag is true when the edge is
// longer than one byte and its bytes were deferred to the next level. Each
// visited node's child list is closed with a 0-bit in louds_.
// On return:
//   - num_keys_ holds the number of distinct keys (from sort_keys());
//   - *terminals maps each key, by its original input position, to the node
//     where it ends (via build_terminals());
//   - keys has been replaced by the deferred edge labels (rest_keys), which
//     build_next() stores in the next level;
//   - link_flags_ is cleared if nothing was deferred.
template <typename T>
void Trie::build_cur(Vector<Key<T> > &keys,
    Vector<UInt32> *terminals, Progress &progress) {
  num_keys_ = sort_keys(keys);
  // Root header: bits 1,0 in louds_, plus a dummy label and a cleared link
  // flag for node 0.
  louds_.push_back(true);
  louds_.push_back(false);
  labels_.push_back('\0');
  link_flags_.push_back(false);

  Vector<Key<T> > rest_keys;
  std::queue<Range> queue;
  Vector<WRange> wranges;
  // Each queued Range covers keys[begin, end), which share their first
  // `pos` bytes.
  queue.push(Range(0, (UInt32)keys.size(), 0));
  while (!queue.empty()) {
    // Every created node pushes exactly one link flag and is queued exactly
    // once. The node at the front of the queue therefore has ID
    // (#nodes created) - (#nodes still queued).
    const UInt32 node = (UInt32)(link_flags_.size() - queue.size());
    Range range = queue.front();
    queue.pop();

    // Keys of length `pos` end at this node. Sorting puts them first in
    // the range.
    while ((range.begin() < range.end()) &&
        (keys[range.begin()].str().length() == range.pos())) {
      keys[range.begin()].set_terminal(node);
      range.set_begin(range.begin() + 1);
    }
    // Leaf: no children, only the closing 0-bit.
    if (range.begin() == range.end()) {
      louds_.push_back(false);
      continue;
    }

    // Group the remaining keys by their byte at range.pos(). Each group
    // becomes one child and carries the summed weight of its keys.
    wranges.clear();
    double weight = keys[range.begin()].weight();
    for (UInt32 i = range.begin() + 1; i < range.end(); ++i) {
      if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) {
        wranges.push_back(WRange(range.begin(), i, range.pos(), weight));
        range.set_begin(i);
        weight = 0.0;
      }
      weight += keys[i].weight();
    }
    wranges.push_back(WRange(range, weight));
    // Optionally reorder children by std::greater<WRange> (presumably
    // heaviest first; see WRange's comparison operators). stable_sort keeps
    // the byte order among ties.
    if (progress.order() == MARISA_WEIGHT_ORDER) {
      std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>());
    }
    if (node == 0) {
      num_first_branches_ = wranges.size();
    }
    for (UInt32 i = 0; i < wranges.size(); ++i) {
      const WRange &wrange = wranges[i];
      UInt32 pos = wrange.pos() + 1;
      // Extend the edge while every key in the group agrees on the byte at
      // `pos`. This is skipped on the last level when no tail is used
      // (presumably because that level has nowhere to store multi-byte
      // labels).
      if ((progress.tail() != MARISA_WITHOUT_TAIL) || !progress.is_last()) {
        while (pos < keys[wrange.begin()].str().length()) {
          UInt32 j;
          for (j = wrange.begin() + 1; j < wrange.end(); ++j) {
            if (keys[j - 1].str()[pos] != keys[j].str()[pos]) {
              break;
            }
          }
          if (j < wrange.end()) {
            break;
          }
          ++pos;
        }
      }
      // Outside Patricia mode, a multi-byte edge is kept only when it runs
      // to the end of the group's last key. Otherwise the edge falls back to
      // a single byte.
      if ((progress.trie() != MARISA_PATRICIA_TRIE) &&
          (pos != keys[wrange.end() - 1].str().length())) {
        pos = wrange.pos() + 1;
      }
      louds_.push_back(true);
      if (pos == wrange.pos() + 1) {
        // Single-byte edge: the byte is stored directly as the label.
        labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]);
        link_flags_.push_back(false);
      } else {
        // Multi-byte edge: store a placeholder label and set the link flag.
        // The edge bytes [wrange.pos(), pos) are deferred to the next level
        // together with the group's weight. build_trie() later overwrites
        // the placeholder.
        labels_.push_back('\0');
        link_flags_.push_back(true);
        Key<T> rest_key;
        rest_key.set_str(keys[wrange.begin()].str().substr(
            wrange.pos(), pos - wrange.pos()));
        rest_key.set_weight(wrange.weight());
        rest_keys.push_back(rest_key);
      }
      // The child node resumes matching at the end of its edge.
      wranges[i].set_pos(pos);
      queue.push(wranges[i].range());
    }
    // Close this node's child list.
    louds_.push_back(false);
  }
  louds_.push_back(false);
  louds_.build();
  // Only the outermost trie (trie_id 0) keeps louds_'s select0 index.
  if (progress.trie_id() != 0) {
    louds_.clear_select0s();
  }
  // No deferred labels: an empty link_flags_ tells the caller that there is
  // no next level.
  if (rest_keys.empty()) {
    link_flags_.clear();
  }

  build_terminals(keys, terminals);
  keys.swap(&rest_keys);
}
259 
build_next(Vector<Key<String>> & keys,Vector<UInt32> * terminals,Progress & progress)260 void Trie::build_next(Vector<Key<String> > &keys,
261     Vector<UInt32> *terminals, Progress &progress) {
262   if (progress.is_last()) {
263     Vector<String> strs;
264     strs.resize(keys.size());
265     for (UInt32 i = 0; i < strs.size(); ++i) {
266       strs[i] = keys[i].str();
267     }
268     tail_.build(strs, terminals, progress.tail());
269     return;
270   }
271   Vector<Key<RString> > rkeys;
272   rkeys.resize(keys.size());
273   for (UInt32 i = 0; i < rkeys.size(); ++i) {
274     rkeys[i].set_str(RString(keys[i].str()));
275     rkeys[i].set_weight(keys[i].weight());
276   }
277   keys.clear();
278   trie_.reset(new (std::nothrow) Trie);
279   MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR);
280   trie_->build_trie(rkeys, terminals, ++progress);
281 }
282 
build_next(Vector<Key<RString>> & rkeys,Vector<UInt32> * terminals,Progress & progress)283 void Trie::build_next(Vector<Key<RString> > &rkeys,
284     Vector<UInt32> *terminals, Progress &progress) {
285   if (progress.is_last()) {
286     Vector<String> strs;
287     strs.resize(rkeys.size());
288     for (UInt32 i = 0; i < strs.size(); ++i) {
289       strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length());
290     }
291     tail_.build(strs, terminals, progress.tail());
292     return;
293   }
294   trie_.reset(new (std::nothrow) Trie);
295   MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR);
296   trie_->build_trie(rkeys, terminals, ++progress);
297 }
298 
299 template <typename T>
sort_keys(Vector<Key<T>> & keys) const300 UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const {
301   if (keys.empty()) {
302     return 0;
303   }
304   for (UInt32 i = 0; i < keys.size(); ++i) {
305     keys[i].set_id(i);
306   }
307   std::sort(keys.begin(), keys.end());
308   UInt32 count = 1;
309   for (UInt32 i = 1; i < keys.size(); ++i) {
310     if (keys[i - 1].str() != keys[i].str()) {
311       ++count;
312     }
313   }
314   return count;
315 }
316 
317 template <typename T>
build_terminals(const Vector<Key<T>> & keys,Vector<UInt32> * terminals) const318 void Trie::build_terminals(const Vector<Key<T> > &keys,
319     Vector<UInt32> *terminals) const {
320   Vector<UInt32> temp_terminals;
321   temp_terminals.resize(keys.size());
322   for (UInt32 i = 0; i < keys.size(); ++i) {
323     temp_terminals[keys[i].id()] = keys[i].terminal();
324   }
325   temp_terminals.swap(terminals);
326 }
327 
328 }  // namespace marisa
329