1 #include "trie.h"
2 
3 extern "C" {
4 
5 namespace {
6 
7 class FindCallback {
8  public:
9   typedef int (*Func)(void *, marisa_alpha_uint32, size_t);
10 
FindCallback(Func func,void * first_arg)11   FindCallback(Func func, void *first_arg)
12       : func_(func), first_arg_(first_arg) {}
FindCallback(const FindCallback & callback)13   FindCallback(const FindCallback &callback)
14       : func_(callback.func_), first_arg_(callback.first_arg_) {}
15 
operator ()(marisa_alpha::UInt32 key_id,std::size_t key_length) const16   bool operator()(marisa_alpha::UInt32 key_id, std::size_t key_length) const {
17     return func_(first_arg_, key_id, key_length) != 0;
18   }
19 
20  private:
21   Func func_;
22   void *first_arg_;
23 
24   // Disallows assignment.
25   FindCallback &operator=(const FindCallback &);
26 };
27 
28 class PredictCallback {
29  public:
30   typedef int (*Func)(void *, marisa_alpha_uint32, const char *, size_t);
31 
PredictCallback(Func func,void * first_arg)32   PredictCallback(Func func, void *first_arg)
33       : func_(func), first_arg_(first_arg) {}
PredictCallback(const PredictCallback & callback)34   PredictCallback(const PredictCallback &callback)
35       : func_(callback.func_), first_arg_(callback.first_arg_) {}
36 
operator ()(marisa_alpha::UInt32 key_id,const std::string & key) const37   bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) const {
38     return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39   }
40 
41  private:
42   Func func_;
43   void *first_arg_;
44 
45   // Disallows assignment.
46   PredictCallback &operator=(const PredictCallback &);
47 };
48 
49 }  // namespace
50 
51 struct marisa_alpha_trie_ {
52  public:
marisa_alpha_trie_marisa_alpha_trie_53   marisa_alpha_trie_() : trie(), mapper() {}
54 
55   marisa_alpha::Trie trie;
56   marisa_alpha::Mapper mapper;
57 
58  private:
59   // Disallows copy and assignment.
60   marisa_alpha_trie_(const marisa_alpha_trie_ &);
61   marisa_alpha_trie_ &operator=(const marisa_alpha_trie_ &);
62 };
63 
marisa_alpha_init(marisa_alpha_trie ** h)64 marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h) {
65   if ((h == NULL) || (*h != NULL)) {
66     return MARISA_ALPHA_HANDLE_ERROR;
67   }
68   *h = new (std::nothrow) marisa_alpha_trie_();
69   return (*h != NULL) ? MARISA_ALPHA_OK : MARISA_ALPHA_MEMORY_ERROR;
70 }
71 
marisa_alpha_end(marisa_alpha_trie * h)72 marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h) {
73   if (h == NULL) {
74     return MARISA_ALPHA_HANDLE_ERROR;
75   }
76   delete h;
77   return MARISA_ALPHA_OK;
78 }
79 
marisa_alpha_build(marisa_alpha_trie * h,const char * const * keys,size_t num_keys,const size_t * key_lengths,const double * key_weights,marisa_alpha_uint32 * key_ids,int flags)80 marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
81     const char * const *keys, size_t num_keys, const size_t *key_lengths,
82     const double *key_weights, marisa_alpha_uint32 *key_ids, int flags) try {
83   if (h == NULL) {
84     return MARISA_ALPHA_HANDLE_ERROR;
85   }
86   h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87   h->mapper.clear();
88   return MARISA_ALPHA_OK;
89 } catch (const marisa_alpha::Exception &ex) {
90   return ex.status();
91 }
92 
marisa_alpha_mmap(marisa_alpha_trie * h,const char * filename,long offset,int whence)93 marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
94     const char *filename, long offset, int whence) try {
95   if (h == NULL) {
96     return MARISA_ALPHA_HANDLE_ERROR;
97   }
98   h->trie.mmap(&h->mapper, filename, offset, whence);
99   return MARISA_ALPHA_OK;
100 } catch (const marisa_alpha::Exception &ex) {
101   return ex.status();
102 }
103 
marisa_alpha_map(marisa_alpha_trie * h,const void * ptr,size_t size)104 marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
105     size_t size) try {
106   if (h == NULL) {
107     return MARISA_ALPHA_HANDLE_ERROR;
108   }
109   h->trie.map(ptr, size);
110   h->mapper.clear();
111   return MARISA_ALPHA_OK;
112 } catch (const marisa_alpha::Exception &ex) {
113   return ex.status();
114 }
115 
marisa_alpha_load(marisa_alpha_trie * h,const char * filename,long offset,int whence)116 marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
117     const char *filename, long offset, int whence) try {
118   if (h == NULL) {
119     return MARISA_ALPHA_HANDLE_ERROR;
120   }
121   h->trie.load(filename, offset, whence);
122   h->mapper.clear();
123   return MARISA_ALPHA_OK;
124 } catch (const marisa_alpha::Exception &ex) {
125   return ex.status();
126 }
127 
marisa_alpha_fread(marisa_alpha_trie * h,FILE * file)128 marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file) try {
129   if (h == NULL) {
130     return MARISA_ALPHA_HANDLE_ERROR;
131   }
132   h->trie.fread(file);
133   h->mapper.clear();
134   return MARISA_ALPHA_OK;
135 } catch (const marisa_alpha::Exception &ex) {
136   return ex.status();
137 }
138 
marisa_alpha_read(marisa_alpha_trie * h,int fd)139 marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd) try {
140   if (h == NULL) {
141     return MARISA_ALPHA_HANDLE_ERROR;
142   }
143   h->trie.read(fd);
144   h->mapper.clear();
145   return MARISA_ALPHA_OK;
146 } catch (const marisa_alpha::Exception &ex) {
147   return ex.status();
148 }
149 
marisa_alpha_save(const marisa_alpha_trie * h,const char * filename,int trunc_flag,long offset,int whence)150 marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
151     const char *filename, int trunc_flag, long offset, int whence) try {
152   if (h == NULL) {
153     return MARISA_ALPHA_HANDLE_ERROR;
154   }
155   h->trie.save(filename, trunc_flag != 0, offset, whence);
156   return MARISA_ALPHA_OK;
157 } catch (const marisa_alpha::Exception &ex) {
158   return ex.status();
159 }
160 
marisa_alpha_fwrite(const marisa_alpha_trie * h,FILE * file)161 marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
162     FILE *file) try {
163   if (h == NULL) {
164     return MARISA_ALPHA_HANDLE_ERROR;
165   }
166   h->trie.fwrite(file);
167   return MARISA_ALPHA_OK;
168 } catch (const marisa_alpha::Exception &ex) {
169   return ex.status();
170 }
171 
marisa_alpha_write(const marisa_alpha_trie * h,int fd)172 marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd) try {
173   if (h == NULL) {
174     return MARISA_ALPHA_HANDLE_ERROR;
175   }
176   h->trie.write(fd);
177   return MARISA_ALPHA_OK;
178 } catch (const marisa_alpha::Exception &ex) {
179   return ex.status();
180 }
181 
marisa_alpha_restore(const marisa_alpha_trie * h,marisa_alpha_uint32 key_id,char * key_buf,size_t key_buf_size,size_t * key_length)182 marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
183     marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
184     size_t *key_length) try {
185   if (h == NULL) {
186     return MARISA_ALPHA_HANDLE_ERROR;
187   } else if (key_length == NULL) {
188     return MARISA_ALPHA_PARAM_ERROR;
189   }
190   *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
191   return MARISA_ALPHA_OK;
192 } catch (const marisa_alpha::Exception &ex) {
193   return ex.status();
194 }
195 
marisa_alpha_lookup(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id)196 marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
197     const char *ptr, size_t length, marisa_alpha_uint32 *key_id) try {
198   if (h == NULL) {
199     return MARISA_ALPHA_HANDLE_ERROR;
200   } else if (key_id == NULL) {
201     return MARISA_ALPHA_PARAM_ERROR;
202   }
203   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
204     *key_id = h->trie.lookup(ptr);
205   } else {
206     *key_id = h->trie.lookup(ptr, length);
207   }
208   return MARISA_ALPHA_OK;
209 } catch (const marisa_alpha::Exception &ex) {
210   return ex.status();
211 }
212 
marisa_alpha_find(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t * key_lengths,size_t max_num_results,size_t * num_results)213 marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
214     const char *ptr, size_t length,
215     marisa_alpha_uint32 *key_ids, size_t *key_lengths,
216     size_t max_num_results, size_t *num_results) try {
217   if (h == NULL) {
218     return MARISA_ALPHA_HANDLE_ERROR;
219   } else if (num_results == NULL) {
220     return MARISA_ALPHA_PARAM_ERROR;
221   }
222   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
223     *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
224   } else {
225     *num_results = h->trie.find(ptr, length,
226         key_ids, key_lengths, max_num_results);
227   }
228   return MARISA_ALPHA_OK;
229 } catch (const marisa_alpha::Exception &ex) {
230   return ex.status();
231 }
232 
marisa_alpha_find_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id,size_t * key_length)233 marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
234     const char *ptr, size_t length,
235     marisa_alpha_uint32 *key_id, size_t *key_length) {
236   if (h == NULL) {
237     return MARISA_ALPHA_HANDLE_ERROR;
238   } else if (key_id == NULL) {
239     return MARISA_ALPHA_PARAM_ERROR;
240   }
241   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
242     *key_id = h->trie.find_first(ptr, key_length);
243   } else {
244     *key_id = h->trie.find_first(ptr, length, key_length);
245   }
246   return MARISA_ALPHA_OK;
247 }
248 
marisa_alpha_find_last(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_id,size_t * key_length)249 marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
250     const char *ptr, size_t length,
251     marisa_alpha_uint32 *key_id, size_t *key_length) {
252   if (h == NULL) {
253     return MARISA_ALPHA_HANDLE_ERROR;
254   } else if (key_id == NULL) {
255     return MARISA_ALPHA_PARAM_ERROR;
256   }
257   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
258     *key_id = h->trie.find_last(ptr, key_length);
259   } else {
260     *key_id = h->trie.find_last(ptr, length, key_length);
261   }
262   return MARISA_ALPHA_OK;
263 }
264 
marisa_alpha_find_callback(const marisa_alpha_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_alpha_uint32,size_t),void * first_arg_to_callback)265 marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
266     const char *ptr, size_t length,
267     int (*callback)(void *, marisa_alpha_uint32, size_t),
268     void *first_arg_to_callback) try {
269   if (h == NULL) {
270     return MARISA_ALPHA_HANDLE_ERROR;
271   } else if (callback == NULL) {
272     return MARISA_ALPHA_PARAM_ERROR;
273   }
274   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
275     h->trie.find_callback(ptr,
276         ::FindCallback(callback, first_arg_to_callback));
277   } else {
278     h->trie.find_callback(ptr, length,
279         ::FindCallback(callback, first_arg_to_callback));
280   }
281   return MARISA_ALPHA_OK;
282 } catch (const marisa_alpha::Exception &ex) {
283   return ex.status();
284 }
285 
marisa_alpha_predict(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)286 marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
287     const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
288     size_t max_num_results, size_t *num_results) {
289   return marisa_alpha_predict_breadth_first(h, ptr, length,
290       key_ids, max_num_results, num_results);
291 }
292 
marisa_alpha_predict_breadth_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)293 marisa_alpha_status marisa_alpha_predict_breadth_first(
294     const marisa_alpha_trie *h, const char *ptr, size_t length,
295     marisa_alpha_uint32 *key_ids, size_t max_num_results,
296     size_t *num_results) try {
297   if (h == NULL) {
298     return MARISA_ALPHA_HANDLE_ERROR;
299   } else if (num_results == NULL) {
300     return MARISA_ALPHA_PARAM_ERROR;
301   }
302   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
303     *num_results = h->trie.predict_breadth_first(
304         ptr, key_ids, NULL, max_num_results);
305   } else {
306     *num_results = h->trie.predict_breadth_first(
307         ptr, length, key_ids, NULL, max_num_results);
308   }
309   return MARISA_ALPHA_OK;
310 } catch (const marisa_alpha::Exception &ex) {
311   return ex.status();
312 }
313 
marisa_alpha_predict_depth_first(const marisa_alpha_trie * h,const char * ptr,size_t length,marisa_alpha_uint32 * key_ids,size_t max_num_results,size_t * num_results)314 marisa_alpha_status marisa_alpha_predict_depth_first(
315     const marisa_alpha_trie *h, const char *ptr, size_t length,
316     marisa_alpha_uint32 *key_ids, size_t max_num_results,
317     size_t *num_results) try {
318   if (h == NULL) {
319     return MARISA_ALPHA_HANDLE_ERROR;
320   } else if (num_results == NULL) {
321     return MARISA_ALPHA_PARAM_ERROR;
322   }
323   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
324     *num_results = h->trie.predict_depth_first(
325         ptr, key_ids, NULL, max_num_results);
326   } else {
327     *num_results = h->trie.predict_depth_first(
328         ptr, length, key_ids, NULL, max_num_results);
329   }
330   return MARISA_ALPHA_OK;
331 } catch (const marisa_alpha::Exception &ex) {
332   return ex.status();
333 }
334 
marisa_alpha_predict_callback(const marisa_alpha_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_alpha_uint32,const char *,size_t),void * first_arg_to_callback)335 marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
336     const char *ptr, size_t length,
337     int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
338     void *first_arg_to_callback) try {
339   if (h == NULL) {
340     return MARISA_ALPHA_HANDLE_ERROR;
341   } else if (callback == NULL) {
342     return MARISA_ALPHA_PARAM_ERROR;
343   }
344   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
345     h->trie.predict_callback(ptr,
346         ::PredictCallback(callback, first_arg_to_callback));
347   } else {
348     h->trie.predict_callback(ptr, length,
349         ::PredictCallback(callback, first_arg_to_callback));
350   }
351   return MARISA_ALPHA_OK;
352 } catch (const marisa_alpha::Exception &ex) {
353   return ex.status();
354 }
355 
marisa_alpha_get_num_tries(const marisa_alpha_trie * h)356 size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h) {
357   return (h != NULL) ? h->trie.num_tries() : 0;
358 }
359 
marisa_alpha_get_num_keys(const marisa_alpha_trie * h)360 size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h) {
361   return (h != NULL) ? h->trie.num_keys() : 0;
362 }
363 
marisa_alpha_get_num_nodes(const marisa_alpha_trie * h)364 size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h) {
365   return (h != NULL) ? h->trie.num_nodes() : 0;
366 }
367 
marisa_alpha_get_total_size(const marisa_alpha_trie * h)368 size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h) {
369   return (h != NULL) ? h->trie.total_size() : 0;
370 }
371 
marisa_alpha_clear(marisa_alpha_trie * h)372 marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h) {
373   if (h == NULL) {
374     return MARISA_ALPHA_HANDLE_ERROR;
375   }
376   h->trie.clear();
377   h->mapper.clear();
378   return MARISA_ALPHA_OK;
379 }
380 
381 }  // extern "C"
382