1 #include <algorithm>
2 #include <stdexcept>
3 
4 #include "trie.h"
5 
6 namespace marisa {
7 namespace {
8 
9 template <typename T, typename U>
10 class PredictCallback {
11  public:
PredictCallback(T key_ids,U keys,std::size_t max_num_results)12   PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13       : key_ids_(key_ids), keys_(keys),
14         max_num_results_(max_num_results), num_results_(0) {}
PredictCallback(const PredictCallback & callback)15   PredictCallback(const PredictCallback &callback)
16       : key_ids_(callback.key_ids_), keys_(callback.keys_),
17         max_num_results_(callback.max_num_results_),
18         num_results_(callback.num_results_) {}
19 
operator ()(marisa::UInt32 key_id,const std::string & key)20   bool operator()(marisa::UInt32 key_id, const std::string &key) {
21     if (key_ids_.is_valid()) {
22       key_ids_.insert(num_results_, key_id);
23     }
24     if (keys_.is_valid()) {
25       keys_.insert(num_results_, key);
26     }
27     return ++num_results_ < max_num_results_;
28   }
29 
30  private:
31   T key_ids_;
32   U keys_;
33   const std::size_t max_num_results_;
34   std::size_t num_results_;
35 
36   // Disallows assignment.
37   PredictCallback &operator=(const PredictCallback &);
38 };
39 
40 }  // namespace
41 
restore(UInt32 key_id) const42 std::string Trie::restore(UInt32 key_id) const {
43   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
44   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
45   std::string key;
46   restore_(key_id, &key);
47   return key;
48 }
49 
restore(UInt32 key_id,std::string * key) const50 void Trie::restore(UInt32 key_id, std::string *key) const {
51   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
52   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
53   MARISA_THROW_IF(key == NULL, MARISA_PARAM_ERROR);
54   restore_(key_id, key);
55 }
56 
restore(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58     std::size_t key_buf_size) const {
59   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
60   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
61   MARISA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62       MARISA_PARAM_ERROR);
63   return restore_(key_id, key_buf, key_buf_size);
64 }
65 
lookup(const char * str) const66 UInt32 Trie::lookup(const char *str) const {
67   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
68   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
69   return lookup_<CQuery>(CQuery(str));
70 }
71 
lookup(const char * ptr,std::size_t length) const72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
74   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
75   return lookup_<const Query &>(Query(ptr, length));
76 }
77 
find(const char * str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const78 std::size_t Trie::find(const char *str,
79     UInt32 *key_ids, std::size_t *key_lengths,
80     std::size_t max_num_results) const {
81   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
82   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
83   return find_<CQuery>(CQuery(str),
84       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
85 }
86 
find(const char * ptr,std::size_t length,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const87 std::size_t Trie::find(const char *ptr, std::size_t length,
88     UInt32 *key_ids, std::size_t *key_lengths,
89     std::size_t max_num_results) const {
90   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
91   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
92   return find_<const Query &>(Query(ptr, length),
93       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
94 }
95 
find(const char * str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const96 std::size_t Trie::find(const char *str,
97     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
98     std::size_t max_num_results) const {
99   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
100   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
101   return find_<CQuery>(CQuery(str),
102       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
103 }
104 
find(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const105 std::size_t Trie::find(const char *ptr, std::size_t length,
106     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
107     std::size_t max_num_results) const {
108   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
109   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
110   return find_<const Query &>(Query(ptr, length),
111       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
112 }
113 
find_first(const char * str,std::size_t * key_length) const114 UInt32 Trie::find_first(const char *str,
115     std::size_t *key_length) const {
116   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
117   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
118   return find_first_<CQuery>(CQuery(str), key_length);
119 }
120 
find_first(const char * ptr,std::size_t length,std::size_t * key_length) const121 UInt32 Trie::find_first(const char *ptr, std::size_t length,
122     std::size_t *key_length) const {
123   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
124   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
125   return find_first_<const Query &>(Query(ptr, length), key_length);
126 }
127 
find_last(const char * str,std::size_t * key_length) const128 UInt32 Trie::find_last(const char *str,
129     std::size_t *key_length) const {
130   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
131   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
132   return find_last_<CQuery>(CQuery(str), key_length);
133 }
134 
find_last(const char * ptr,std::size_t length,std::size_t * key_length) const135 UInt32 Trie::find_last(const char *ptr, std::size_t length,
136     std::size_t *key_length) const {
137   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
138   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
139   return find_last_<const Query &>(Query(ptr, length), key_length);
140 }
141 
predict(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const142 std::size_t Trie::predict(const char *str,
143     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
144   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
145   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
146   return (keys == NULL) ?
147       predict_breadth_first(str, key_ids, keys, max_num_results) :
148       predict_depth_first(str, key_ids, keys, max_num_results);
149 }
150 
predict(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const151 std::size_t Trie::predict(const char *ptr, std::size_t length,
152     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
153   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
154   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
155   return (keys == NULL) ?
156       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
157       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
158 }
159 
predict(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const160 std::size_t Trie::predict(const char *str,
161     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
162     std::size_t max_num_results) const {
163   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
164   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
165   return (keys == NULL) ?
166       predict_breadth_first(str, key_ids, keys, max_num_results) :
167       predict_depth_first(str, key_ids, keys, max_num_results);
168 }
169 
predict(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const170 std::size_t Trie::predict(const char *ptr, std::size_t length,
171     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
172     std::size_t max_num_results) const {
173   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
174   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
175   return (keys == NULL) ?
176       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
177       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
178 }
179 
predict_breadth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const180 std::size_t Trie::predict_breadth_first(const char *str,
181     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
182   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
183   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
184   return predict_breadth_first_<CQuery>(CQuery(str),
185       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
186 }
187 
predict_breadth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const188 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
189     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
190   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
191   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
192   return predict_breadth_first_<const Query &>(Query(ptr, length),
193       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
194 }
195 
predict_breadth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const196 std::size_t Trie::predict_breadth_first(const char *str,
197     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
198     std::size_t max_num_results) const {
199   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
200   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
201   return predict_breadth_first_<CQuery>(CQuery(str),
202       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
203 }
204 
predict_breadth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const205 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
206     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
207     std::size_t max_num_results) const {
208   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
209   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
210   return predict_breadth_first_<const Query &>(Query(ptr, length),
211       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
212 }
213 
predict_depth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const214 std::size_t Trie::predict_depth_first(const char *str,
215     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
216   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
217   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
218   return predict_depth_first_<CQuery>(CQuery(str),
219       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
220 }
221 
predict_depth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const222 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
223     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
224   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
225   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
226   return predict_depth_first_<const Query &>(Query(ptr, length),
227       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
228 }
229 
predict_depth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const230 std::size_t Trie::predict_depth_first(
231     const char *str, std::vector<UInt32> *key_ids,
232     std::vector<std::string> *keys, std::size_t max_num_results) const {
233   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
234   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
235   return predict_depth_first_<CQuery>(CQuery(str),
236       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
237 }
238 
predict_depth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const239 std::size_t Trie::predict_depth_first(
240     const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
241     std::vector<std::string> *keys, std::size_t max_num_results) const {
242   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
243   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
244   return predict_depth_first_<const Query &>(Query(ptr, length),
245       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
246 }
247 
restore_(UInt32 key_id,std::string * key) const248 void Trie::restore_(UInt32 key_id, std::string *key) const {
249   const std::size_t start_pos = key->length();
250   UInt32 node = key_id_to_node(key_id);
251   while (node != 0) {
252     if (has_link(node)) {
253       const std::size_t prev_pos = key->length();
254       if (has_trie()) {
255         trie_->trie_restore(get_link(node), key);
256       } else {
257         tail_restore(node, key);
258       }
259       std::reverse(key->begin() + prev_pos, key->end());
260     } else {
261       *key += labels_[node];
262     }
263     node = get_parent(node);
264   }
265   std::reverse(key->begin() + start_pos, key->end());
266 }
267 
trie_restore(UInt32 node,std::string * key) const268 void Trie::trie_restore(UInt32 node, std::string *key) const {
269   do {
270     if (has_link(node)) {
271       if (has_trie()) {
272         trie_->trie_restore(get_link(node), key);
273       } else {
274         tail_restore(node, key);
275       }
276     } else {
277       *key += labels_[node];
278     }
279     node = get_parent(node);
280   } while (node != 0);
281 }
282 
tail_restore(UInt32 node,std::string * key) const283 void Trie::tail_restore(UInt32 node, std::string *key) const {
284   const UInt32 link_id = link_flags_.rank1(node);
285   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
286   if (tail_.mode() == MARISA_BINARY_TAIL) {
287     const UInt32 length = (links_[link_id + 1] * 256)
288         + labels_[link_flags_.select1(link_id + 1)] - offset;
289     key->append(reinterpret_cast<const char *>(tail_[offset]), length);
290   } else {
291     key->append(reinterpret_cast<const char *>(tail_[offset]));
292   }
293 }
294 
restore_(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const295 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
296     std::size_t key_buf_size) const {
297   std::size_t pos = 0;
298   UInt32 node = key_id_to_node(key_id);
299   while (node != 0) {
300     if (has_link(node)) {
301       const std::size_t prev_pos = pos;
302       if (has_trie()) {
303         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
304       } else {
305         tail_restore(node, key_buf, key_buf_size, pos);
306       }
307       if (pos < key_buf_size) {
308         std::reverse(key_buf + prev_pos, key_buf + pos);
309       }
310     } else {
311       if (pos < key_buf_size) {
312         key_buf[pos] = labels_[node];
313       }
314       ++pos;
315     }
316     node = get_parent(node);
317   }
318   if (pos < key_buf_size) {
319     key_buf[pos] = '\0';
320     std::reverse(key_buf, key_buf + pos);
321   }
322   return pos;
323 }
324 
trie_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const325 void Trie::trie_restore(UInt32 node, char *key_buf,
326     std::size_t key_buf_size, std::size_t &pos) const {
327   do {
328     if (has_link(node)) {
329       if (has_trie()) {
330         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
331       } else {
332         tail_restore(node, key_buf, key_buf_size, pos);
333       }
334     } else {
335       if (pos < key_buf_size) {
336         key_buf[pos] = labels_[node];
337       }
338       ++pos;
339     }
340     node = get_parent(node);
341   } while (node != 0);
342 }
343 
tail_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const344 void Trie::tail_restore(UInt32 node, char *key_buf,
345     std::size_t key_buf_size, std::size_t &pos) const {
346   const UInt32 link_id = link_flags_.rank1(node);
347   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
348   if (tail_.mode() == MARISA_BINARY_TAIL) {
349     const UInt8 *ptr = tail_[offset];
350     const UInt32 length = (links_[link_id + 1] * 256)
351         + labels_[link_flags_.select1(link_id + 1)] - offset;
352     for (UInt32 i = 0; i < length; ++i) {
353       if (pos < key_buf_size) {
354         key_buf[pos] = ptr[i];
355       }
356       ++pos;
357     }
358   } else {
359     for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
360       if (pos < key_buf_size) {
361         key_buf[pos] = *str;
362       }
363       ++pos;
364     }
365   }
366 }
367 
368 template <typename T>
lookup_(T query) const369 UInt32 Trie::lookup_(T query) const {
370   UInt32 node = 0;
371   std::size_t pos = 0;
372   while (!query.ends_at(pos)) {
373     if (!find_child<T>(node, query, pos)) {
374       return notfound();
375     }
376   }
377   return terminal_flags_[node] ? node_to_key_id(node) : notfound();
378 }
379 
380 template <typename T>
trie_match(UInt32 node,T query,std::size_t pos) const381 std::size_t Trie::trie_match(UInt32 node, T query,
382     std::size_t pos) const {
383   if (has_link(node)) {
384     std::size_t next_pos;
385     if (has_trie()) {
386       next_pos = trie_->trie_match<T>(get_link(node), query, pos);
387     } else {
388       next_pos = tail_match<T>(node, get_link_id(node), query, pos);
389     }
390     if ((next_pos == mismatch()) || (next_pos == pos)) {
391       return next_pos;
392     }
393     pos = next_pos;
394   } else if (labels_[node] != query[pos]) {
395     return pos;
396   } else {
397     ++pos;
398   }
399   node = get_parent(node);
400   while (node != 0) {
401     if (query.ends_at(pos)) {
402       return mismatch();
403     }
404     if (has_link(node)) {
405       std::size_t next_pos;
406       if (has_trie()) {
407         next_pos = trie_->trie_match<T>(get_link(node), query, pos);
408       } else {
409         next_pos = tail_match<T>(node, get_link_id(node), query, pos);
410       }
411       if ((next_pos == mismatch()) || (next_pos == pos)) {
412         return mismatch();
413       }
414       pos = next_pos;
415     } else if (labels_[node] != query[pos]) {
416       return mismatch();
417     } else {
418       ++pos;
419     }
420     node = get_parent(node);
421   }
422   return pos;
423 }
424 
425 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
426     CQuery query, std::size_t pos) const;
427 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
428     const Query &query, std::size_t pos) const;
429 
430 template <typename T>
tail_match(UInt32 node,UInt32 link_id,T query,std::size_t pos) const431 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
432     T query, std::size_t pos) const {
433   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
434   const UInt8 *ptr = tail_[offset];
435   if (*ptr != query[pos]) {
436     return pos;
437   } else if (tail_.mode() == MARISA_BINARY_TAIL) {
438     const UInt32 length = (links_[link_id + 1] * 256)
439         + labels_[link_flags_.select1(link_id + 1)] - offset;
440     for (UInt32 i = 1; i < length; ++i) {
441       if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
442         return mismatch();
443       }
444     }
445     return pos + length;
446   } else {
447     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
448       if (query.ends_at(pos) || (*ptr != query[pos])) {
449         return mismatch();
450       }
451     }
452     return pos;
453   }
454 }
455 
456 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
457     UInt32 link_id, CQuery query, std::size_t pos) const;
458 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
459     UInt32 link_id, const Query &query, std::size_t pos) const;
460 
461 template <typename T, typename U, typename V>
find_(T query,U key_ids,V key_lengths,std::size_t max_num_results) const462 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
463     std::size_t max_num_results) const {
464   if (max_num_results == 0) {
465     return 0;
466   }
467   std::size_t count = 0;
468   UInt32 node = 0;
469   std::size_t pos = 0;
470   do {
471     if (terminal_flags_[node]) {
472       if (key_ids.is_valid()) {
473         key_ids.insert(count, node_to_key_id(node));
474       }
475       if (key_lengths.is_valid()) {
476         key_lengths.insert(count, pos);
477       }
478       if (++count >= max_num_results) {
479         return count;
480       }
481     }
482   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
483   return count;
484 }
485 
486 template <typename T>
find_first_(T query,std::size_t * key_length) const487 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
488   UInt32 node = 0;
489   std::size_t pos = 0;
490   do {
491     if (terminal_flags_[node]) {
492       if (key_length != NULL) {
493         *key_length = pos;
494       }
495       return node_to_key_id(node);
496     }
497   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
498   return notfound();
499 }
500 
501 template <typename T>
find_last_(T query,std::size_t * key_length) const502 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
503   UInt32 node = 0;
504   UInt32 node_found = notfound();
505   std::size_t pos = 0;
506   std::size_t pos_found = mismatch();
507   do {
508     if (terminal_flags_[node]) {
509       node_found = node;
510       pos_found = pos;
511     }
512   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
513   if (node_found != notfound()) {
514     if (key_length != NULL) {
515       *key_length = pos_found;
516     }
517     return node_to_key_id(node_found);
518   }
519   return notfound();
520 }
521 
522 template <typename T, typename U, typename V>
predict_breadth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const523 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
524     std::size_t max_num_results) const {
525   if (max_num_results == 0) {
526     return 0;
527   }
528   UInt32 node = 0;
529   std::size_t pos = 0;
530   while (!query.ends_at(pos)) {
531     if (!predict_child<T>(node, query, pos, NULL)) {
532       return 0;
533     }
534   }
535   std::string key;
536   std::size_t count = 0;
537   if (terminal_flags_[node]) {
538     const UInt32 key_id = node_to_key_id(node);
539     if (key_ids.is_valid()) {
540       key_ids.insert(count, key_id);
541     }
542     if (keys.is_valid()) {
543       restore(key_id, &key);
544       keys.insert(count, key);
545     }
546     if (++count >= max_num_results) {
547       return count;
548     }
549   }
550   const UInt32 louds_pos = get_child(node);
551   if (!louds_[louds_pos]) {
552     return count;
553   }
554   UInt32 node_begin = louds_pos_to_node(louds_pos, node);
555   UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
556   while (node_begin < node_end) {
557     const UInt32 key_id_begin = node_to_key_id(node_begin);
558     const UInt32 key_id_end = node_to_key_id(node_end);
559     if (key_ids.is_valid()) {
560       UInt32 temp_count = count;
561       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
562         key_ids.insert(temp_count, key_id);
563         if (++temp_count >= max_num_results) {
564           break;
565         }
566       }
567     }
568     if (keys.is_valid()) {
569       UInt32 temp_count = count;
570       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
571         key.clear();
572         restore(key_id, &key);
573         keys.insert(temp_count, key);
574         if (++temp_count >= max_num_results) {
575           break;
576         }
577       }
578     }
579     count += key_id_end - key_id_begin;
580     if (count >= max_num_results) {
581       return max_num_results;
582     }
583     node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
584     node_end = louds_pos_to_node(get_child(node_end), node_end);
585   }
586   return count;
587 }
588 
589 template <typename T, typename U, typename V>
predict_depth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const590 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
591     std::size_t max_num_results) const {
592   if (max_num_results == 0) {
593     return 0;
594   } else if (keys.is_valid()) {
595     PredictCallback<U, V> callback(key_ids, keys, max_num_results);
596     return predict_callback_(query, callback);
597   }
598 
599   UInt32 node = 0;
600   std::size_t pos = 0;
601   while (!query.ends_at(pos)) {
602     if (!predict_child<T>(node, query, pos, NULL)) {
603       return 0;
604     }
605   }
606   std::size_t count = 0;
607   if (terminal_flags_[node]) {
608     if (key_ids.is_valid()) {
609       key_ids.insert(count, node_to_key_id(node));
610     }
611     if (++count >= max_num_results) {
612       return count;
613     }
614   }
615   Cell cell;
616   cell.set_louds_pos(get_child(node));
617   if (!louds_[cell.louds_pos()]) {
618     return count;
619   }
620   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
621   cell.set_key_id(node_to_key_id(cell.node()));
622   Vector<Cell> stack;
623   stack.push_back(cell);
624   std::size_t stack_pos = 1;
625   while (stack_pos != 0) {
626     Cell &cur = stack[stack_pos - 1];
627     if (!louds_[cur.louds_pos()]) {
628       cur.set_louds_pos(cur.louds_pos() + 1);
629       --stack_pos;
630       continue;
631     }
632     cur.set_louds_pos(cur.louds_pos() + 1);
633     if (terminal_flags_[cur.node()]) {
634       if (key_ids.is_valid()) {
635         key_ids.insert(count, cur.key_id());
636       }
637       if (++count >= max_num_results) {
638         return count;
639       }
640       cur.set_key_id(cur.key_id() + 1);
641     }
642     if (stack_pos == stack.size()) {
643       cell.set_louds_pos(get_child(cur.node()));
644       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
645       cell.set_key_id(node_to_key_id(cell.node()));
646       stack.push_back(cell);
647     }
648     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
649     ++stack_pos;
650   }
651   return count;
652 }
653 
654 template <typename T>
trie_prefix_match(UInt32 node,T query,std::size_t pos,std::string * key) const655 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
656     std::size_t pos, std::string *key) const {
657   if (has_link(node)) {
658     std::size_t next_pos;
659     if (has_trie()) {
660       next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
661     } else {
662       next_pos = tail_prefix_match<T>(
663           node, get_link_id(node), query, pos, key);
664     }
665     if ((next_pos == mismatch()) || (next_pos == pos)) {
666       return next_pos;
667     }
668     pos = next_pos;
669   } else if (labels_[node] != query[pos]) {
670     return pos;
671   } else {
672     ++pos;
673   }
674   node = get_parent(node);
675   while (node != 0) {
676     if (query.ends_at(pos)) {
677       if (key != NULL) {
678         trie_restore(node, key);
679       }
680       return pos;
681     }
682     if (has_link(node)) {
683       std::size_t next_pos;
684       if (has_trie()) {
685         next_pos = trie_->trie_prefix_match<T>(
686             get_link(node), query, pos, key);
687       } else {
688         next_pos = tail_prefix_match<T>(
689             node, get_link_id(node), query, pos, key);
690       }
691       if ((next_pos == mismatch()) || (next_pos == pos)) {
692         return next_pos;
693       }
694       pos = next_pos;
695     } else if (labels_[node] != query[pos]) {
696       return mismatch();
697     } else {
698       ++pos;
699     }
700     node = get_parent(node);
701   }
702   return pos;
703 }
704 
705 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
706     CQuery query, std::size_t pos, std::string *key) const;
707 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
708     const Query &query, std::size_t pos, std::string *key) const;
709 
710 template <typename T>
tail_prefix_match(UInt32 node,UInt32 link_id,T query,std::size_t pos,std::string * key) const711 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
712     T query, std::size_t pos, std::string *key) const {
713   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
714   const UInt8 *ptr = tail_[offset];
715   if (*ptr != query[pos]) {
716     return pos;
717   } else if (tail_.mode() == MARISA_BINARY_TAIL) {
718     const UInt32 length = (links_[link_id + 1] * 256)
719         + labels_[link_flags_.select1(link_id + 1)] - offset;
720     for (UInt32 i = 1; i < length; ++i) {
721       if (query.ends_at(pos + i)) {
722         if (key != NULL) {
723           key->append(reinterpret_cast<const char *>(ptr + i), length - i);
724         }
725         return pos + i;
726       } else if (ptr[i] != query[pos + i]) {
727         return mismatch();
728       }
729     }
730     return pos + length;
731   } else {
732     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
733       if (query.ends_at(pos)) {
734         if (key != NULL) {
735           key->append(reinterpret_cast<const char *>(ptr));
736         }
737         return pos;
738       } else if (*ptr != query[pos]) {
739         return mismatch();
740       }
741     }
742     return pos;
743   }
744 }
745 
746 template std::size_t Trie::tail_prefix_match<CQuery>(
747     UInt32 node, UInt32 link_id,
748     CQuery query, std::size_t pos, std::string *key) const;
749 template std::size_t Trie::tail_prefix_match<const Query &>(
750     UInt32 node, UInt32 link_id,
751     const Query &query, std::size_t pos, std::string *key) const;
752 
753 }  // namespace marisa
754