#include "trie.h"

#include <new>
2
3 extern "C" {
4
5 namespace {
6
7 class FindCallback {
8 public:
9 typedef int (*Func)(void *, marisa_alpha_uint32, size_t);
10
FindCallback(Func func,void * first_arg)11 FindCallback(Func func, void *first_arg)
12 : func_(func), first_arg_(first_arg) {}
FindCallback(const FindCallback & callback)13 FindCallback(const FindCallback &callback)
14 : func_(callback.func_), first_arg_(callback.first_arg_) {}
15
operator ()(marisa_alpha::UInt32 key_id,std::size_t key_length) const16 bool operator()(marisa_alpha::UInt32 key_id, std::size_t key_length) const {
17 return func_(first_arg_, key_id, key_length) != 0;
18 }
19
20 private:
21 Func func_;
22 void *first_arg_;
23
24 // Disallows assignment.
25 FindCallback &operator=(const FindCallback &);
26 };
27
28 class PredictCallback {
29 public:
30 typedef int (*Func)(void *, marisa_alpha_uint32, const char *, size_t);
31
PredictCallback(Func func,void * first_arg)32 PredictCallback(Func func, void *first_arg)
33 : func_(func), first_arg_(first_arg) {}
PredictCallback(const PredictCallback & callback)34 PredictCallback(const PredictCallback &callback)
35 : func_(callback.func_), first_arg_(callback.first_arg_) {}
36
operator ()(marisa_alpha::UInt32 key_id,const std::string & key) const37 bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) const {
38 return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39 }
40
41 private:
42 Func func_;
43 void *first_arg_;
44
45 // Disallows assignment.
46 PredictCallback &operator=(const PredictCallback &);
47 };
48
49 } // namespace
50
51 struct marisa_alpha_trie_ {
52 public:
marisa_alpha_trie_marisa_alpha_trie_53 marisa_alpha_trie_() : trie(), mapper() {}
54
55 marisa_alpha::Trie trie;
56 marisa_alpha::Mapper mapper;
57
58 private:
59 // Disallows copy and assignment.
60 marisa_alpha_trie_(const marisa_alpha_trie_ &);
61 marisa_alpha_trie_ &operator=(const marisa_alpha_trie_ &);
62 };
63
marisa_alpha_init(marisa_alpha_trie ** h)64 marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h) {
65 if ((h == NULL) || (*h != NULL)) {
66 return MARISA_ALPHA_HANDLE_ERROR;
67 }
68 *h = new (std::nothrow) marisa_alpha_trie_();
69 return (*h != NULL) ? MARISA_ALPHA_OK : MARISA_ALPHA_MEMORY_ERROR;
70 }
71
marisa_alpha_end(marisa_alpha_trie * h)72 marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h) {
73 if (h == NULL) {
74 return MARISA_ALPHA_HANDLE_ERROR;
75 }
76 delete h;
77 return MARISA_ALPHA_OK;
78 }
79
marisa_alpha_build(marisa_alpha_trie * h,const char * const * keys,size_t num_keys,const size_t * key_lengths,const double * key_weights,marisa_alpha_uint32 * key_ids,int flags)80 marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
81 const char * const *keys, size_t num_keys, const size_t *key_lengths,
82 const double *key_weights, marisa_alpha_uint32 *key_ids, int flags) try {
83 if (h == NULL) {
84 return MARISA_ALPHA_HANDLE_ERROR;
85 }
86 h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87 h->mapper.clear();
88 return MARISA_ALPHA_OK;
89 } catch (const marisa_alpha::Exception &ex) {
90 return ex.status();
91 }
92
marisa_alpha_mmap(marisa_alpha_trie * h,const char * filename,long offset,int whence)93 marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
94 const char *filename, long offset, int whence) try {
95 if (h == NULL) {
96 return MARISA_ALPHA_HANDLE_ERROR;
97 }
98 h->trie.mmap(&h->mapper, filename, offset, whence);
99 return MARISA_ALPHA_OK;
100 } catch (const marisa_alpha::Exception &ex) {
101 return ex.status();
102 }
103
marisa_alpha_map(marisa_alpha_trie * h,const void * ptr,size_t size)104 marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
105 size_t size) try {
106 if (h == NULL) {
107 return MARISA_ALPHA_HANDLE_ERROR;
108 }
109 h->trie.map(ptr, size);
110 h->mapper.clear();
111 return MARISA_ALPHA_OK;
112 } catch (const marisa_alpha::Exception &ex) {
113 return ex.status();
114 }
115
marisa_alpha_load(marisa_alpha_trie * h,const char * filename,long offset,int whence)116 marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
117 const char *filename, long offset, int whence) try {
118 if (h == NULL) {
119 return MARISA_ALPHA_HANDLE_ERROR;
120 }
121 h->trie.load(filename, offset, whence);
122 h->mapper.clear();
123 return MARISA_ALPHA_OK;
124 } catch (const marisa_alpha::Exception &ex) {
125 return ex.status();
126 }
127
marisa_alpha_fread(marisa_alpha_trie * h,FILE * file)128 marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file) try {
129 if (h == NULL) {
130 return MARISA_ALPHA_HANDLE_ERROR;
131 }
132 h->trie.fread(file);
133 h->mapper.clear();
134 return MARISA_ALPHA_OK;
135 } catch (const marisa_alpha::Exception &ex) {
136 return ex.status();
137 }
138
marisa_alpha_read(marisa_alpha_trie * h,int fd)139 marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd) try {
140 if (h == NULL) {
141 return MARISA_ALPHA_HANDLE_ERROR;
142 }
143 h->trie.read(fd);
144 h->mapper.clear();
145 return MARISA_ALPHA_OK;
146 } catch (const marisa_alpha::Exception &ex) {
147 return ex.status();
148 }
149
marisa_alpha_save(const marisa_alpha_trie * h,const char * filename,int trunc_flag,long offset,int whence)150 marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
151 const char *filename, int trunc_flag, long offset, int whence) try {
152 if (h == NULL) {
153 return MARISA_ALPHA_HANDLE_ERROR;
154 }
155 h->trie.save(filename, trunc_flag != 0, offset, whence);
156 return MARISA_ALPHA_OK;
157 } catch (const marisa_alpha::Exception &ex) {
158 return ex.status();
159 }
160
marisa_alpha_fwrite(const marisa_alpha_trie * h,FILE * file)161 marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
162 FILE *file) try {
163 if (h == NULL) {
164 return MARISA_ALPHA_HANDLE_ERROR;
165 }
166 h->trie.fwrite(file);
167 return MARISA_ALPHA_OK;
168 } catch (const marisa_alpha::Exception &ex) {
169 return ex.status();
170 }
171
marisa_alpha_write(const marisa_alpha_trie * h,int fd)172 marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd) try {
173 if (h == NULL) {
174 return MARISA_ALPHA_HANDLE_ERROR;
175 }
176 h->trie.write(fd);
177 return MARISA_ALPHA_OK;
178 } catch (const marisa_alpha::Exception &ex) {
179 return ex.status();
180 }
181
marisa_alpha_restore(const marisa_alpha_trie * h,marisa_alpha_uint32 key_id,char * key_buf,size_t key_buf_size,size_t * key_length)182 marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
183 marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
184 size_t *key_length) try {
185 if (h == NULL) {
186 return MARISA_ALPHA_HANDLE_ERROR;
187 } else if (key_length == NULL) {
188 return MARISA_ALPHA_PARAM_ERROR;
189 }
190 *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
191 return MARISA_ALPHA_OK;
192 } catch (const marisa_alpha::Exception &ex) {
193 return ex.status();
194 }
195
marisa_alpha_lookup(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id)196 marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
197 const char *ptr, size_t length, marisa_alpha_uint32 *key_id) try {
198 if (h == NULL) {
199 return MARISA_ALPHA_HANDLE_ERROR;
200 } else if (key_id == NULL) {
201 return MARISA_ALPHA_PARAM_ERROR;
202 }
203 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
204 *key_id = h->trie.lookup(ptr);
205 } else {
206 *key_id = h->trie.lookup(ptr, length);
207 }
208 return MARISA_ALPHA_OK;
209 } catch (const marisa_alpha::Exception &ex) {
210 return ex.status();
211 }
212
marisa_alpha_find(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t * key_lengths,size_t max_num_results,size_t * num_results)213 marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
214 const char *ptr, size_t length,
215 marisa_alpha_uint32 *key_ids, size_t *key_lengths,
216 size_t max_num_results, size_t *num_results) try {
217 if (h == NULL) {
218 return MARISA_ALPHA_HANDLE_ERROR;
219 } else if (num_results == NULL) {
220 return MARISA_ALPHA_PARAM_ERROR;
221 }
222 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
223 *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
224 } else {
225 *num_results = h->trie.find(ptr, length,
226 key_ids, key_lengths, max_num_results);
227 }
228 return MARISA_ALPHA_OK;
229 } catch (const marisa_alpha::Exception &ex) {
230 return ex.status();
231 }
232
marisa_alpha_find_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id,size_t * key_length)233 marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
234 const char *ptr, size_t length,
235 marisa_alpha_uint32 *key_id, size_t *key_length) {
236 if (h == NULL) {
237 return MARISA_ALPHA_HANDLE_ERROR;
238 } else if (key_id == NULL) {
239 return MARISA_ALPHA_PARAM_ERROR;
240 }
241 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
242 *key_id = h->trie.find_first(ptr, key_length);
243 } else {
244 *key_id = h->trie.find_first(ptr, length, key_length);
245 }
246 return MARISA_ALPHA_OK;
247 }
248
marisa_alpha_find_last(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id,size_t * key_length)249 marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
250 const char *ptr, size_t length,
251 marisa_alpha_uint32 *key_id, size_t *key_length) {
252 if (h == NULL) {
253 return MARISA_ALPHA_HANDLE_ERROR;
254 } else if (key_id == NULL) {
255 return MARISA_ALPHA_PARAM_ERROR;
256 }
257 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
258 *key_id = h->trie.find_last(ptr, key_length);
259 } else {
260 *key_id = h->trie.find_last(ptr, length, key_length);
261 }
262 return MARISA_ALPHA_OK;
263 }
264
marisa_alpha_find_callback(const marisa_alpha_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_alpha_uint32,size_t),void * first_arg_to_callback)265 marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
266 const char *ptr, size_t length,
267 int (*callback)(void *, marisa_alpha_uint32, size_t),
268 void *first_arg_to_callback) try {
269 if (h == NULL) {
270 return MARISA_ALPHA_HANDLE_ERROR;
271 } else if (callback == NULL) {
272 return MARISA_ALPHA_PARAM_ERROR;
273 }
274 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
275 h->trie.find_callback(ptr,
276 ::FindCallback(callback, first_arg_to_callback));
277 } else {
278 h->trie.find_callback(ptr, length,
279 ::FindCallback(callback, first_arg_to_callback));
280 }
281 return MARISA_ALPHA_OK;
282 } catch (const marisa_alpha::Exception &ex) {
283 return ex.status();
284 }
285
marisa_alpha_predict(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)286 marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
287 const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
288 size_t max_num_results, size_t *num_results) {
289 return marisa_alpha_predict_breadth_first(h, ptr, length,
290 key_ids, max_num_results, num_results);
291 }
292
marisa_alpha_predict_breadth_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)293 marisa_alpha_status marisa_alpha_predict_breadth_first(
294 const marisa_alpha_trie *h, const char *ptr, size_t length,
295 marisa_alpha_uint32 *key_ids, size_t max_num_results,
296 size_t *num_results) try {
297 if (h == NULL) {
298 return MARISA_ALPHA_HANDLE_ERROR;
299 } else if (num_results == NULL) {
300 return MARISA_ALPHA_PARAM_ERROR;
301 }
302 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
303 *num_results = h->trie.predict_breadth_first(
304 ptr, key_ids, NULL, max_num_results);
305 } else {
306 *num_results = h->trie.predict_breadth_first(
307 ptr, length, key_ids, NULL, max_num_results);
308 }
309 return MARISA_ALPHA_OK;
310 } catch (const marisa_alpha::Exception &ex) {
311 return ex.status();
312 }
313
marisa_alpha_predict_depth_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)314 marisa_alpha_status marisa_alpha_predict_depth_first(
315 const marisa_alpha_trie *h, const char *ptr, size_t length,
316 marisa_alpha_uint32 *key_ids, size_t max_num_results,
317 size_t *num_results) try {
318 if (h == NULL) {
319 return MARISA_ALPHA_HANDLE_ERROR;
320 } else if (num_results == NULL) {
321 return MARISA_ALPHA_PARAM_ERROR;
322 }
323 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
324 *num_results = h->trie.predict_depth_first(
325 ptr, key_ids, NULL, max_num_results);
326 } else {
327 *num_results = h->trie.predict_depth_first(
328 ptr, length, key_ids, NULL, max_num_results);
329 }
330 return MARISA_ALPHA_OK;
331 } catch (const marisa_alpha::Exception &ex) {
332 return ex.status();
333 }
334
marisa_alpha_predict_callback(const marisa_alpha_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_alpha_uint32,const char *,size_t),void * first_arg_to_callback)335 marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
336 const char *ptr, size_t length,
337 int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
338 void *first_arg_to_callback) try {
339 if (h == NULL) {
340 return MARISA_ALPHA_HANDLE_ERROR;
341 } else if (callback == NULL) {
342 return MARISA_ALPHA_PARAM_ERROR;
343 }
344 if (length == MARISA_ALPHA_ZERO_TERMINATED) {
345 h->trie.predict_callback(ptr,
346 ::PredictCallback(callback, first_arg_to_callback));
347 } else {
348 h->trie.predict_callback(ptr, length,
349 ::PredictCallback(callback, first_arg_to_callback));
350 }
351 return MARISA_ALPHA_OK;
352 } catch (const marisa_alpha::Exception &ex) {
353 return ex.status();
354 }
355
marisa_alpha_get_num_tries(const marisa_alpha_trie * h)356 size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h) {
357 return (h != NULL) ? h->trie.num_tries() : 0;
358 }
359
marisa_alpha_get_num_keys(const marisa_alpha_trie * h)360 size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h) {
361 return (h != NULL) ? h->trie.num_keys() : 0;
362 }
363
marisa_alpha_get_num_nodes(const marisa_alpha_trie * h)364 size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h) {
365 return (h != NULL) ? h->trie.num_nodes() : 0;
366 }
367
marisa_alpha_get_total_size(const marisa_alpha_trie * h)368 size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h) {
369 return (h != NULL) ? h->trie.total_size() : 0;
370 }
371
marisa_alpha_clear(marisa_alpha_trie * h)372 marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h) {
373 if (h == NULL) {
374 return MARISA_ALPHA_HANDLE_ERROR;
375 }
376 h->trie.clear();
377 h->mapper.clear();
378 return MARISA_ALPHA_OK;
379 }
380
381 } // extern "C"
382