#include "trie.h"

#include <new>
2 
3 extern "C" {
4 
5 namespace {
6 
7 class FindCallback {
8  public:
9   typedef int (*Func)(void *, marisa_uint32, size_t);
10 
FindCallback(Func func,void * first_arg)11   FindCallback(Func func, void *first_arg)
12       : func_(func), first_arg_(first_arg) {}
FindCallback(const FindCallback & callback)13   FindCallback(const FindCallback &callback)
14       : func_(callback.func_), first_arg_(callback.first_arg_) {}
15 
operator ()(marisa::UInt32 key_id,std::size_t key_length) const16   bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
17     return func_(first_arg_, key_id, key_length) != 0;
18   }
19 
20  private:
21   Func func_;
22   void *first_arg_;
23 
24   // Disallows assignment.
25   FindCallback &operator=(const FindCallback &);
26 };
27 
28 class PredictCallback {
29  public:
30   typedef int (*Func)(void *, marisa_uint32, const char *, size_t);
31 
PredictCallback(Func func,void * first_arg)32   PredictCallback(Func func, void *first_arg)
33       : func_(func), first_arg_(first_arg) {}
PredictCallback(const PredictCallback & callback)34   PredictCallback(const PredictCallback &callback)
35       : func_(callback.func_), first_arg_(callback.first_arg_) {}
36 
operator ()(marisa::UInt32 key_id,const std::string & key) const37   bool operator()(marisa::UInt32 key_id, const std::string &key) const {
38     return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39   }
40 
41  private:
42   Func func_;
43   void *first_arg_;
44 
45   // Disallows assignment.
46   PredictCallback &operator=(const PredictCallback &);
47 };
48 
49 }  // namespace
50 
51 struct marisa_trie_ {
52  public:
marisa_trie_marisa_trie_53   marisa_trie_() : trie(), mapper() {}
54 
55   marisa::Trie trie;
56   marisa::Mapper mapper;
57 
58  private:
59   // Disallows copy and assignment.
60   marisa_trie_(const marisa_trie_ &);
61   marisa_trie_ &operator=(const marisa_trie_ &);
62 };
63 
marisa_init(marisa_trie ** h)64 marisa_status marisa_init(marisa_trie **h) {
65   if ((h == NULL) || (*h != NULL)) {
66     return MARISA_HANDLE_ERROR;
67   }
68   *h = new (std::nothrow) marisa_trie_();
69   return (*h != NULL) ? MARISA_OK : MARISA_MEMORY_ERROR;
70 }
71 
marisa_end(marisa_trie * h)72 marisa_status marisa_end(marisa_trie *h) {
73   if (h == NULL) {
74     return MARISA_HANDLE_ERROR;
75   }
76   delete h;
77   return MARISA_OK;
78 }
79 
marisa_build(marisa_trie * h,const char * const * keys,size_t num_keys,const size_t * key_lengths,const double * key_weights,marisa_uint32 * key_ids,int flags)80 marisa_status marisa_build(marisa_trie *h, const char * const *keys,
81     size_t num_keys, const size_t *key_lengths, const double *key_weights,
82     marisa_uint32 *key_ids, int flags) {
83   if (h == NULL) {
84     return MARISA_HANDLE_ERROR;
85   }
86   h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87   h->mapper.clear();
88   return MARISA_OK;
89 }
90 
marisa_mmap(marisa_trie * h,const char * filename,long offset,int whence)91 marisa_status marisa_mmap(marisa_trie *h, const char *filename,
92     long offset, int whence) {
93   if (h == NULL) {
94     return MARISA_HANDLE_ERROR;
95   }
96   h->trie.mmap(&h->mapper, filename, offset, whence);
97   return MARISA_OK;
98 }
99 
marisa_map(marisa_trie * h,const void * ptr,size_t size)100 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size) {
101   if (h == NULL) {
102     return MARISA_HANDLE_ERROR;
103   }
104   h->trie.map(ptr, size);
105   h->mapper.clear();
106   return MARISA_OK;
107 }
108 
marisa_load(marisa_trie * h,const char * filename,long offset,int whence)109 marisa_status marisa_load(marisa_trie *h, const char *filename,
110     long offset, int whence) {
111   if (h == NULL) {
112     return MARISA_HANDLE_ERROR;
113   }
114   h->trie.load(filename, offset, whence);
115   h->mapper.clear();
116   return MARISA_OK;
117 }
118 
marisa_fread(marisa_trie * h,FILE * file)119 marisa_status marisa_fread(marisa_trie *h, FILE *file) {
120   if (h == NULL) {
121     return MARISA_HANDLE_ERROR;
122   }
123   h->trie.fread(file);
124   h->mapper.clear();
125   return MARISA_OK;
126 }
127 
marisa_read(marisa_trie * h,int fd)128 marisa_status marisa_read(marisa_trie *h, int fd) {
129   if (h == NULL) {
130     return MARISA_HANDLE_ERROR;
131   }
132   h->trie.read(fd);
133   h->mapper.clear();
134   return MARISA_OK;
135 }
136 
marisa_save(const marisa_trie * h,const char * filename,int trunc_flag,long offset,int whence)137 marisa_status marisa_save(const marisa_trie *h, const char *filename,
138     int trunc_flag, long offset, int whence) {
139   if (h == NULL) {
140     return MARISA_HANDLE_ERROR;
141   }
142   h->trie.save(filename, trunc_flag != 0, offset, whence);
143   return MARISA_OK;
144 }
145 
marisa_fwrite(const marisa_trie * h,FILE * file)146 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file) {
147   if (h == NULL) {
148     return MARISA_HANDLE_ERROR;
149   }
150   h->trie.fwrite(file);
151   return MARISA_OK;
152 }
153 
marisa_write(const marisa_trie * h,int fd)154 marisa_status marisa_write(const marisa_trie *h, int fd) {
155   if (h == NULL) {
156     return MARISA_HANDLE_ERROR;
157   }
158   h->trie.write(fd);
159   return MARISA_OK;
160 }
161 
marisa_restore(const marisa_trie * h,marisa_uint32 key_id,char * key_buf,size_t key_buf_size,size_t * key_length)162 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id,
163     char *key_buf, size_t key_buf_size, size_t *key_length) {
164   if (h == NULL) {
165     return MARISA_HANDLE_ERROR;
166   } else if (key_length == NULL) {
167     return MARISA_PARAM_ERROR;
168   }
169   *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
170   return MARISA_OK;
171 }
172 
marisa_lookup(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id)173 marisa_status marisa_lookup(const marisa_trie *h,
174     const char *ptr, size_t length, marisa_uint32 *key_id) {
175   if (h == NULL) {
176     return MARISA_HANDLE_ERROR;
177   } else if (key_id == NULL) {
178     return MARISA_PARAM_ERROR;
179   }
180   if (length == MARISA_ZERO_TERMINATED) {
181     *key_id = h->trie.lookup(ptr);
182   } else {
183     *key_id = h->trie.lookup(ptr, length);
184   }
185   return MARISA_OK;
186 }
187 
marisa_find(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t * key_lengths,size_t max_num_results,size_t * num_results)188 marisa_status marisa_find(const marisa_trie *h,
189     const char *ptr, size_t length,
190     marisa_uint32 *key_ids, size_t *key_lengths,
191     size_t max_num_results, size_t *num_results) {
192   if (h == NULL) {
193     return MARISA_HANDLE_ERROR;
194   } else if (num_results == NULL) {
195     return MARISA_PARAM_ERROR;
196   }
197   if (length == MARISA_ZERO_TERMINATED) {
198     *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
199   } else {
200     *num_results = h->trie.find(ptr, length,
201         key_ids, key_lengths, max_num_results);
202   }
203   return MARISA_OK;
204 }
205 
marisa_find_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id,size_t * key_length)206 marisa_status marisa_find_first(const marisa_trie *h,
207     const char *ptr, size_t length,
208     marisa_uint32 *key_id, size_t *key_length) {
209   if (h == NULL) {
210     return MARISA_HANDLE_ERROR;
211   } else if (key_id == NULL) {
212     return MARISA_PARAM_ERROR;
213   }
214   if (length == MARISA_ZERO_TERMINATED) {
215     *key_id = h->trie.find_first(ptr, key_length);
216   } else {
217     *key_id = h->trie.find_first(ptr, length, key_length);
218   }
219   return MARISA_OK;
220 }
221 
marisa_find_last(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id,size_t * key_length)222 marisa_status marisa_find_last(const marisa_trie *h,
223     const char *ptr, size_t length,
224     marisa_uint32 *key_id, size_t *key_length) {
225   if (h == NULL) {
226     return MARISA_HANDLE_ERROR;
227   } else if (key_id == NULL) {
228     return MARISA_PARAM_ERROR;
229   }
230   if (length == MARISA_ZERO_TERMINATED) {
231     *key_id = h->trie.find_last(ptr, key_length);
232   } else {
233     *key_id = h->trie.find_last(ptr, length, key_length);
234   }
235   return MARISA_OK;
236 }
237 
marisa_find_callback(const marisa_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_uint32,size_t),void * first_arg_to_callback)238 marisa_status marisa_find_callback(const marisa_trie *h,
239     const char *ptr, size_t length,
240     int (*callback)(void *, marisa_uint32, size_t),
241     void *first_arg_to_callback) {
242   if (h == NULL) {
243     return MARISA_HANDLE_ERROR;
244   } else if (callback == NULL) {
245     return MARISA_PARAM_ERROR;
246   }
247   if (length == MARISA_ZERO_TERMINATED) {
248     h->trie.find_callback(ptr,
249         ::FindCallback(callback, first_arg_to_callback));
250   } else {
251     h->trie.find_callback(ptr, length,
252         ::FindCallback(callback, first_arg_to_callback));
253   }
254   return MARISA_OK;
255 }
256 
marisa_predict(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)257 marisa_status marisa_predict(const marisa_trie *h,
258     const char *ptr, size_t length, marisa_uint32 *key_ids,
259     size_t max_num_results, size_t *num_results) {
260   return marisa_predict_breadth_first(h, ptr, length,
261       key_ids, max_num_results, num_results);
262 }
263 
marisa_predict_breadth_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)264 marisa_status marisa_predict_breadth_first(const marisa_trie *h,
265     const char *ptr, size_t length, marisa_uint32 *key_ids,
266     size_t max_num_results, size_t *num_results) {
267   if (h == NULL) {
268     return MARISA_HANDLE_ERROR;
269   } else if (num_results == NULL) {
270     return MARISA_PARAM_ERROR;
271   }
272   if (length == MARISA_ZERO_TERMINATED) {
273     *num_results = h->trie.predict_breadth_first(
274         ptr, key_ids, NULL, max_num_results);
275   } else {
276     *num_results = h->trie.predict_breadth_first(
277         ptr, length, key_ids, NULL, max_num_results);
278   }
279   return MARISA_OK;
280 }
281 
marisa_predict_depth_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)282 marisa_status marisa_predict_depth_first(const marisa_trie *h,
283     const char *ptr, size_t length, marisa_uint32 *key_ids,
284     size_t max_num_results, size_t *num_results) {
285   if (h == NULL) {
286     return MARISA_HANDLE_ERROR;
287   } else if (num_results == NULL) {
288     return MARISA_PARAM_ERROR;
289   }
290   if (length == MARISA_ZERO_TERMINATED) {
291     *num_results = h->trie.predict_depth_first(
292         ptr, key_ids, NULL, max_num_results);
293   } else {
294     *num_results = h->trie.predict_depth_first(
295         ptr, length, key_ids, NULL, max_num_results);
296   }
297   return MARISA_OK;
298 }
299 
marisa_predict_callback(const marisa_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_uint32,const char *,size_t),void * first_arg_to_callback)300 marisa_status marisa_predict_callback(const marisa_trie *h,
301     const char *ptr, size_t length,
302     int (*callback)(void *, marisa_uint32, const char *, size_t),
303     void *first_arg_to_callback) {
304   if (h == NULL) {
305     return MARISA_HANDLE_ERROR;
306   } else if (callback == NULL) {
307     return MARISA_PARAM_ERROR;
308   }
309   if (length == MARISA_ZERO_TERMINATED) {
310     h->trie.predict_callback(ptr,
311         ::PredictCallback(callback, first_arg_to_callback));
312   } else {
313     h->trie.predict_callback(ptr, length,
314         ::PredictCallback(callback, first_arg_to_callback));
315   }
316   return MARISA_OK;
317 }
318 
marisa_get_num_tries(const marisa_trie * h)319 size_t marisa_get_num_tries(const marisa_trie *h) {
320   return (h != NULL) ? h->trie.num_tries() : 0;
321 }
322 
marisa_get_num_keys(const marisa_trie * h)323 size_t marisa_get_num_keys(const marisa_trie *h) {
324   return (h != NULL) ? h->trie.num_keys() : 0;
325 }
326 
marisa_get_num_nodes(const marisa_trie * h)327 size_t marisa_get_num_nodes(const marisa_trie *h) {
328   return (h != NULL) ? h->trie.num_nodes() : 0;
329 }
330 
marisa_get_total_size(const marisa_trie * h)331 size_t marisa_get_total_size(const marisa_trie *h) {
332   return (h != NULL) ? h->trie.total_size() : 0;
333 }
334 
marisa_clear(marisa_trie * h)335 marisa_status marisa_clear(marisa_trie *h) {
336   if (h == NULL) {
337     return MARISA_HANDLE_ERROR;
338   }
339   h->trie.clear();
340   h->mapper.clear();
341   return MARISA_OK;
342 }
343 
344 }  // extern "C"
345