1 #include <algorithm>
2 #include <stdexcept>
3
4 #include "trie.h"
5
6 namespace marisa {
7 namespace {
8
9 template <typename T, typename U>
10 class PredictCallback {
11 public:
PredictCallback(T key_ids,U keys,std::size_t max_num_results)12 PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13 : key_ids_(key_ids), keys_(keys),
14 max_num_results_(max_num_results), num_results_(0) {}
PredictCallback(const PredictCallback & callback)15 PredictCallback(const PredictCallback &callback)
16 : key_ids_(callback.key_ids_), keys_(callback.keys_),
17 max_num_results_(callback.max_num_results_),
18 num_results_(callback.num_results_) {}
19
operator ()(marisa::UInt32 key_id,const std::string & key)20 bool operator()(marisa::UInt32 key_id, const std::string &key) {
21 if (key_ids_.is_valid()) {
22 key_ids_.insert(num_results_, key_id);
23 }
24 if (keys_.is_valid()) {
25 keys_.insert(num_results_, key);
26 }
27 return ++num_results_ < max_num_results_;
28 }
29
30 private:
31 T key_ids_;
32 U keys_;
33 const std::size_t max_num_results_;
34 std::size_t num_results_;
35
36 // Disallows assignment.
37 PredictCallback &operator=(const PredictCallback &);
38 };
39
40 } // namespace
41
restore(UInt32 key_id) const42 std::string Trie::restore(UInt32 key_id) const {
43 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
44 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
45 std::string key;
46 restore_(key_id, &key);
47 return key;
48 }
49
restore(UInt32 key_id,std::string * key) const50 void Trie::restore(UInt32 key_id, std::string *key) const {
51 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
52 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
53 MARISA_THROW_IF(key == NULL, MARISA_PARAM_ERROR);
54 restore_(key_id, key);
55 }
56
restore(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58 std::size_t key_buf_size) const {
59 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
60 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
61 MARISA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62 MARISA_PARAM_ERROR);
63 return restore_(key_id, key_buf, key_buf_size);
64 }
65
lookup(const char * str) const66 UInt32 Trie::lookup(const char *str) const {
67 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
68 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
69 return lookup_<CQuery>(CQuery(str));
70 }
71
lookup(const char * ptr,std::size_t length) const72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
74 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
75 return lookup_<const Query &>(Query(ptr, length));
76 }
77
find(const char * str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const78 std::size_t Trie::find(const char *str,
79 UInt32 *key_ids, std::size_t *key_lengths,
80 std::size_t max_num_results) const {
81 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
82 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
83 return find_<CQuery>(CQuery(str),
84 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
85 }
86
find(const char * ptr,std::size_t length,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const87 std::size_t Trie::find(const char *ptr, std::size_t length,
88 UInt32 *key_ids, std::size_t *key_lengths,
89 std::size_t max_num_results) const {
90 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
91 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
92 return find_<const Query &>(Query(ptr, length),
93 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
94 }
95
find(const char * str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const96 std::size_t Trie::find(const char *str,
97 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
98 std::size_t max_num_results) const {
99 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
100 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
101 return find_<CQuery>(CQuery(str),
102 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
103 }
104
find(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const105 std::size_t Trie::find(const char *ptr, std::size_t length,
106 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
107 std::size_t max_num_results) const {
108 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
109 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
110 return find_<const Query &>(Query(ptr, length),
111 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
112 }
113
find_first(const char * str,std::size_t * key_length) const114 UInt32 Trie::find_first(const char *str,
115 std::size_t *key_length) const {
116 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
117 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
118 return find_first_<CQuery>(CQuery(str), key_length);
119 }
120
find_first(const char * ptr,std::size_t length,std::size_t * key_length) const121 UInt32 Trie::find_first(const char *ptr, std::size_t length,
122 std::size_t *key_length) const {
123 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
124 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
125 return find_first_<const Query &>(Query(ptr, length), key_length);
126 }
127
find_last(const char * str,std::size_t * key_length) const128 UInt32 Trie::find_last(const char *str,
129 std::size_t *key_length) const {
130 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
131 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
132 return find_last_<CQuery>(CQuery(str), key_length);
133 }
134
find_last(const char * ptr,std::size_t length,std::size_t * key_length) const135 UInt32 Trie::find_last(const char *ptr, std::size_t length,
136 std::size_t *key_length) const {
137 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
138 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
139 return find_last_<const Query &>(Query(ptr, length), key_length);
140 }
141
predict(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const142 std::size_t Trie::predict(const char *str,
143 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
144 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
145 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
146 return (keys == NULL) ?
147 predict_breadth_first(str, key_ids, keys, max_num_results) :
148 predict_depth_first(str, key_ids, keys, max_num_results);
149 }
150
predict(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const151 std::size_t Trie::predict(const char *ptr, std::size_t length,
152 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
153 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
154 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
155 return (keys == NULL) ?
156 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
157 predict_depth_first(ptr, length, key_ids, keys, max_num_results);
158 }
159
predict(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const160 std::size_t Trie::predict(const char *str,
161 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
162 std::size_t max_num_results) const {
163 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
164 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
165 return (keys == NULL) ?
166 predict_breadth_first(str, key_ids, keys, max_num_results) :
167 predict_depth_first(str, key_ids, keys, max_num_results);
168 }
169
predict(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const170 std::size_t Trie::predict(const char *ptr, std::size_t length,
171 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
172 std::size_t max_num_results) const {
173 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
174 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
175 return (keys == NULL) ?
176 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
177 predict_depth_first(ptr, length, key_ids, keys, max_num_results);
178 }
179
predict_breadth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const180 std::size_t Trie::predict_breadth_first(const char *str,
181 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
182 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
183 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
184 return predict_breadth_first_<CQuery>(CQuery(str),
185 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
186 }
187
predict_breadth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const188 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
189 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
190 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
191 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
192 return predict_breadth_first_<const Query &>(Query(ptr, length),
193 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
194 }
195
predict_breadth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const196 std::size_t Trie::predict_breadth_first(const char *str,
197 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
198 std::size_t max_num_results) const {
199 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
200 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
201 return predict_breadth_first_<CQuery>(CQuery(str),
202 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
203 }
204
predict_breadth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const205 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
206 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
207 std::size_t max_num_results) const {
208 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
209 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
210 return predict_breadth_first_<const Query &>(Query(ptr, length),
211 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
212 }
213
predict_depth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const214 std::size_t Trie::predict_depth_first(const char *str,
215 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
216 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
217 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
218 return predict_depth_first_<CQuery>(CQuery(str),
219 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
220 }
221
predict_depth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const222 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
223 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
224 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
225 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
226 return predict_depth_first_<const Query &>(Query(ptr, length),
227 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
228 }
229
predict_depth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const230 std::size_t Trie::predict_depth_first(
231 const char *str, std::vector<UInt32> *key_ids,
232 std::vector<std::string> *keys, std::size_t max_num_results) const {
233 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
234 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
235 return predict_depth_first_<CQuery>(CQuery(str),
236 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
237 }
238
predict_depth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const239 std::size_t Trie::predict_depth_first(
240 const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
241 std::vector<std::string> *keys, std::size_t max_num_results) const {
242 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
243 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
244 return predict_depth_first_<const Query &>(Query(ptr, length),
245 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
246 }
247
restore_(UInt32 key_id,std::string * key) const248 void Trie::restore_(UInt32 key_id, std::string *key) const {
249 const std::size_t start_pos = key->length();
250 UInt32 node = key_id_to_node(key_id);
251 while (node != 0) {
252 if (has_link(node)) {
253 const std::size_t prev_pos = key->length();
254 if (has_trie()) {
255 trie_->trie_restore(get_link(node), key);
256 } else {
257 tail_restore(node, key);
258 }
259 std::reverse(key->begin() + prev_pos, key->end());
260 } else {
261 *key += labels_[node];
262 }
263 node = get_parent(node);
264 }
265 std::reverse(key->begin() + start_pos, key->end());
266 }
267
trie_restore(UInt32 node,std::string * key) const268 void Trie::trie_restore(UInt32 node, std::string *key) const {
269 do {
270 if (has_link(node)) {
271 if (has_trie()) {
272 trie_->trie_restore(get_link(node), key);
273 } else {
274 tail_restore(node, key);
275 }
276 } else {
277 *key += labels_[node];
278 }
279 node = get_parent(node);
280 } while (node != 0);
281 }
282
tail_restore(UInt32 node,std::string * key) const283 void Trie::tail_restore(UInt32 node, std::string *key) const {
284 const UInt32 link_id = link_flags_.rank1(node);
285 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
286 if (tail_.mode() == MARISA_BINARY_TAIL) {
287 const UInt32 length = (links_[link_id + 1] * 256)
288 + labels_[link_flags_.select1(link_id + 1)] - offset;
289 key->append(reinterpret_cast<const char *>(tail_[offset]), length);
290 } else {
291 key->append(reinterpret_cast<const char *>(tail_[offset]));
292 }
293 }
294
restore_(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const295 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
296 std::size_t key_buf_size) const {
297 std::size_t pos = 0;
298 UInt32 node = key_id_to_node(key_id);
299 while (node != 0) {
300 if (has_link(node)) {
301 const std::size_t prev_pos = pos;
302 if (has_trie()) {
303 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
304 } else {
305 tail_restore(node, key_buf, key_buf_size, pos);
306 }
307 if (pos < key_buf_size) {
308 std::reverse(key_buf + prev_pos, key_buf + pos);
309 }
310 } else {
311 if (pos < key_buf_size) {
312 key_buf[pos] = labels_[node];
313 }
314 ++pos;
315 }
316 node = get_parent(node);
317 }
318 if (pos < key_buf_size) {
319 key_buf[pos] = '\0';
320 std::reverse(key_buf, key_buf + pos);
321 }
322 return pos;
323 }
324
trie_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const325 void Trie::trie_restore(UInt32 node, char *key_buf,
326 std::size_t key_buf_size, std::size_t &pos) const {
327 do {
328 if (has_link(node)) {
329 if (has_trie()) {
330 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
331 } else {
332 tail_restore(node, key_buf, key_buf_size, pos);
333 }
334 } else {
335 if (pos < key_buf_size) {
336 key_buf[pos] = labels_[node];
337 }
338 ++pos;
339 }
340 node = get_parent(node);
341 } while (node != 0);
342 }
343
tail_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const344 void Trie::tail_restore(UInt32 node, char *key_buf,
345 std::size_t key_buf_size, std::size_t &pos) const {
346 const UInt32 link_id = link_flags_.rank1(node);
347 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
348 if (tail_.mode() == MARISA_BINARY_TAIL) {
349 const UInt8 *ptr = tail_[offset];
350 const UInt32 length = (links_[link_id + 1] * 256)
351 + labels_[link_flags_.select1(link_id + 1)] - offset;
352 for (UInt32 i = 0; i < length; ++i) {
353 if (pos < key_buf_size) {
354 key_buf[pos] = ptr[i];
355 }
356 ++pos;
357 }
358 } else {
359 for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
360 if (pos < key_buf_size) {
361 key_buf[pos] = *str;
362 }
363 ++pos;
364 }
365 }
366 }
367
368 template <typename T>
lookup_(T query) const369 UInt32 Trie::lookup_(T query) const {
370 UInt32 node = 0;
371 std::size_t pos = 0;
372 while (!query.ends_at(pos)) {
373 if (!find_child<T>(node, query, pos)) {
374 return notfound();
375 }
376 }
377 return terminal_flags_[node] ? node_to_key_id(node) : notfound();
378 }
379
380 template <typename T>
trie_match(UInt32 node,T query,std::size_t pos) const381 std::size_t Trie::trie_match(UInt32 node, T query,
382 std::size_t pos) const {
383 if (has_link(node)) {
384 std::size_t next_pos;
385 if (has_trie()) {
386 next_pos = trie_->trie_match<T>(get_link(node), query, pos);
387 } else {
388 next_pos = tail_match<T>(node, get_link_id(node), query, pos);
389 }
390 if ((next_pos == mismatch()) || (next_pos == pos)) {
391 return next_pos;
392 }
393 pos = next_pos;
394 } else if (labels_[node] != query[pos]) {
395 return pos;
396 } else {
397 ++pos;
398 }
399 node = get_parent(node);
400 while (node != 0) {
401 if (query.ends_at(pos)) {
402 return mismatch();
403 }
404 if (has_link(node)) {
405 std::size_t next_pos;
406 if (has_trie()) {
407 next_pos = trie_->trie_match<T>(get_link(node), query, pos);
408 } else {
409 next_pos = tail_match<T>(node, get_link_id(node), query, pos);
410 }
411 if ((next_pos == mismatch()) || (next_pos == pos)) {
412 return mismatch();
413 }
414 pos = next_pos;
415 } else if (labels_[node] != query[pos]) {
416 return mismatch();
417 } else {
418 ++pos;
419 }
420 node = get_parent(node);
421 }
422 return pos;
423 }
424
425 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
426 CQuery query, std::size_t pos) const;
427 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
428 const Query &query, std::size_t pos) const;
429
430 template <typename T>
tail_match(UInt32 node,UInt32 link_id,T query,std::size_t pos) const431 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
432 T query, std::size_t pos) const {
433 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
434 const UInt8 *ptr = tail_[offset];
435 if (*ptr != query[pos]) {
436 return pos;
437 } else if (tail_.mode() == MARISA_BINARY_TAIL) {
438 const UInt32 length = (links_[link_id + 1] * 256)
439 + labels_[link_flags_.select1(link_id + 1)] - offset;
440 for (UInt32 i = 1; i < length; ++i) {
441 if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
442 return mismatch();
443 }
444 }
445 return pos + length;
446 } else {
447 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
448 if (query.ends_at(pos) || (*ptr != query[pos])) {
449 return mismatch();
450 }
451 }
452 return pos;
453 }
454 }
455
456 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
457 UInt32 link_id, CQuery query, std::size_t pos) const;
458 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
459 UInt32 link_id, const Query &query, std::size_t pos) const;
460
461 template <typename T, typename U, typename V>
find_(T query,U key_ids,V key_lengths,std::size_t max_num_results) const462 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
463 std::size_t max_num_results) const {
464 if (max_num_results == 0) {
465 return 0;
466 }
467 std::size_t count = 0;
468 UInt32 node = 0;
469 std::size_t pos = 0;
470 do {
471 if (terminal_flags_[node]) {
472 if (key_ids.is_valid()) {
473 key_ids.insert(count, node_to_key_id(node));
474 }
475 if (key_lengths.is_valid()) {
476 key_lengths.insert(count, pos);
477 }
478 if (++count >= max_num_results) {
479 return count;
480 }
481 }
482 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
483 return count;
484 }
485
486 template <typename T>
find_first_(T query,std::size_t * key_length) const487 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
488 UInt32 node = 0;
489 std::size_t pos = 0;
490 do {
491 if (terminal_flags_[node]) {
492 if (key_length != NULL) {
493 *key_length = pos;
494 }
495 return node_to_key_id(node);
496 }
497 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
498 return notfound();
499 }
500
501 template <typename T>
find_last_(T query,std::size_t * key_length) const502 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
503 UInt32 node = 0;
504 UInt32 node_found = notfound();
505 std::size_t pos = 0;
506 std::size_t pos_found = mismatch();
507 do {
508 if (terminal_flags_[node]) {
509 node_found = node;
510 pos_found = pos;
511 }
512 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
513 if (node_found != notfound()) {
514 if (key_length != NULL) {
515 *key_length = pos_found;
516 }
517 return node_to_key_id(node_found);
518 }
519 return notfound();
520 }
521
522 template <typename T, typename U, typename V>
predict_breadth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const523 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
524 std::size_t max_num_results) const {
525 if (max_num_results == 0) {
526 return 0;
527 }
528 UInt32 node = 0;
529 std::size_t pos = 0;
530 while (!query.ends_at(pos)) {
531 if (!predict_child<T>(node, query, pos, NULL)) {
532 return 0;
533 }
534 }
535 std::string key;
536 std::size_t count = 0;
537 if (terminal_flags_[node]) {
538 const UInt32 key_id = node_to_key_id(node);
539 if (key_ids.is_valid()) {
540 key_ids.insert(count, key_id);
541 }
542 if (keys.is_valid()) {
543 restore(key_id, &key);
544 keys.insert(count, key);
545 }
546 if (++count >= max_num_results) {
547 return count;
548 }
549 }
550 const UInt32 louds_pos = get_child(node);
551 if (!louds_[louds_pos]) {
552 return count;
553 }
554 UInt32 node_begin = louds_pos_to_node(louds_pos, node);
555 UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
556 while (node_begin < node_end) {
557 const UInt32 key_id_begin = node_to_key_id(node_begin);
558 const UInt32 key_id_end = node_to_key_id(node_end);
559 if (key_ids.is_valid()) {
560 UInt32 temp_count = count;
561 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
562 key_ids.insert(temp_count, key_id);
563 if (++temp_count >= max_num_results) {
564 break;
565 }
566 }
567 }
568 if (keys.is_valid()) {
569 UInt32 temp_count = count;
570 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
571 key.clear();
572 restore(key_id, &key);
573 keys.insert(temp_count, key);
574 if (++temp_count >= max_num_results) {
575 break;
576 }
577 }
578 }
579 count += key_id_end - key_id_begin;
580 if (count >= max_num_results) {
581 return max_num_results;
582 }
583 node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
584 node_end = louds_pos_to_node(get_child(node_end), node_end);
585 }
586 return count;
587 }
588
589 template <typename T, typename U, typename V>
predict_depth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const590 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
591 std::size_t max_num_results) const {
592 if (max_num_results == 0) {
593 return 0;
594 } else if (keys.is_valid()) {
595 PredictCallback<U, V> callback(key_ids, keys, max_num_results);
596 return predict_callback_(query, callback);
597 }
598
599 UInt32 node = 0;
600 std::size_t pos = 0;
601 while (!query.ends_at(pos)) {
602 if (!predict_child<T>(node, query, pos, NULL)) {
603 return 0;
604 }
605 }
606 std::size_t count = 0;
607 if (terminal_flags_[node]) {
608 if (key_ids.is_valid()) {
609 key_ids.insert(count, node_to_key_id(node));
610 }
611 if (++count >= max_num_results) {
612 return count;
613 }
614 }
615 Cell cell;
616 cell.set_louds_pos(get_child(node));
617 if (!louds_[cell.louds_pos()]) {
618 return count;
619 }
620 cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
621 cell.set_key_id(node_to_key_id(cell.node()));
622 Vector<Cell> stack;
623 stack.push_back(cell);
624 std::size_t stack_pos = 1;
625 while (stack_pos != 0) {
626 Cell &cur = stack[stack_pos - 1];
627 if (!louds_[cur.louds_pos()]) {
628 cur.set_louds_pos(cur.louds_pos() + 1);
629 --stack_pos;
630 continue;
631 }
632 cur.set_louds_pos(cur.louds_pos() + 1);
633 if (terminal_flags_[cur.node()]) {
634 if (key_ids.is_valid()) {
635 key_ids.insert(count, cur.key_id());
636 }
637 if (++count >= max_num_results) {
638 return count;
639 }
640 cur.set_key_id(cur.key_id() + 1);
641 }
642 if (stack_pos == stack.size()) {
643 cell.set_louds_pos(get_child(cur.node()));
644 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
645 cell.set_key_id(node_to_key_id(cell.node()));
646 stack.push_back(cell);
647 }
648 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
649 ++stack_pos;
650 }
651 return count;
652 }
653
654 template <typename T>
trie_prefix_match(UInt32 node,T query,std::size_t pos,std::string * key) const655 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
656 std::size_t pos, std::string *key) const {
657 if (has_link(node)) {
658 std::size_t next_pos;
659 if (has_trie()) {
660 next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
661 } else {
662 next_pos = tail_prefix_match<T>(
663 node, get_link_id(node), query, pos, key);
664 }
665 if ((next_pos == mismatch()) || (next_pos == pos)) {
666 return next_pos;
667 }
668 pos = next_pos;
669 } else if (labels_[node] != query[pos]) {
670 return pos;
671 } else {
672 ++pos;
673 }
674 node = get_parent(node);
675 while (node != 0) {
676 if (query.ends_at(pos)) {
677 if (key != NULL) {
678 trie_restore(node, key);
679 }
680 return pos;
681 }
682 if (has_link(node)) {
683 std::size_t next_pos;
684 if (has_trie()) {
685 next_pos = trie_->trie_prefix_match<T>(
686 get_link(node), query, pos, key);
687 } else {
688 next_pos = tail_prefix_match<T>(
689 node, get_link_id(node), query, pos, key);
690 }
691 if ((next_pos == mismatch()) || (next_pos == pos)) {
692 return next_pos;
693 }
694 pos = next_pos;
695 } else if (labels_[node] != query[pos]) {
696 return mismatch();
697 } else {
698 ++pos;
699 }
700 node = get_parent(node);
701 }
702 return pos;
703 }
704
705 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
706 CQuery query, std::size_t pos, std::string *key) const;
707 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
708 const Query &query, std::size_t pos, std::string *key) const;
709
710 template <typename T>
tail_prefix_match(UInt32 node,UInt32 link_id,T query,std::size_t pos,std::string * key) const711 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
712 T query, std::size_t pos, std::string *key) const {
713 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
714 const UInt8 *ptr = tail_[offset];
715 if (*ptr != query[pos]) {
716 return pos;
717 } else if (tail_.mode() == MARISA_BINARY_TAIL) {
718 const UInt32 length = (links_[link_id + 1] * 256)
719 + labels_[link_flags_.select1(link_id + 1)] - offset;
720 for (UInt32 i = 1; i < length; ++i) {
721 if (query.ends_at(pos + i)) {
722 if (key != NULL) {
723 key->append(reinterpret_cast<const char *>(ptr + i), length - i);
724 }
725 return pos + i;
726 } else if (ptr[i] != query[pos + i]) {
727 return mismatch();
728 }
729 }
730 return pos + length;
731 } else {
732 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
733 if (query.ends_at(pos)) {
734 if (key != NULL) {
735 key->append(reinterpret_cast<const char *>(ptr));
736 }
737 return pos;
738 } else if (*ptr != query[pos]) {
739 return mismatch();
740 }
741 }
742 return pos;
743 }
744 }
745
746 template std::size_t Trie::tail_prefix_match<CQuery>(
747 UInt32 node, UInt32 link_id,
748 CQuery query, std::size_t pos, std::string *key) const;
749 template std::size_t Trie::tail_prefix_match<const Query &>(
750 UInt32 node, UInt32 link_id,
751 const Query &query, std::size_t pos, std::string *key) const;
752
753 } // namespace marisa
754