1 #include <cstdlib>
2 #include <cstring>
3 #include <ctime>
4 #include <sstream>
5 
6 #include <marisa.h>
7 
8 #include "marisa-assert.h"
9 
10 namespace {
11 
// Exercises a marisa::Trie in three states: never built, built from an empty
// keyset, and built from a keyset holding only the empty string "".
//
// ASSERT(cond) and EXCEPT(expr, code) come from marisa-assert.h. EXCEPT is
// read here as "expr must throw a marisa exception carrying `code`".
// NOTE(review): confirm that reading against marisa-assert.h.
void TestEmptyTrie() {
  TEST_START();

  marisa::Trie trie;

  // Before build(), every output path must be rejected with
  // MARISA_STATE_ERROR: file name, file descriptor, iostream and FILE*.
  EXCEPT(trie.save("marisa-test.dat"), MARISA_STATE_ERROR);
#ifdef _MSC_VER
  EXCEPT(trie.write(::_fileno(stdout)), MARISA_STATE_ERROR);
#else  // _MSC_VER
  EXCEPT(trie.write(::fileno(stdout)), MARISA_STATE_ERROR);
#endif  // _MSC_VER
  EXCEPT(std::cout << trie, MARISA_STATE_ERROR);
  EXCEPT(marisa::fwrite(stdout, trie), MARISA_STATE_ERROR);

  marisa::Agent agent;

  // ...so must every search operation...
  EXCEPT(trie.lookup(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.reverse_lookup(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.common_prefix_search(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.predictive_search(agent), MARISA_STATE_ERROR);

  // ...and every accessor, including the size/emptiness queries.
  EXCEPT(trie.num_tries(), MARISA_STATE_ERROR);
  EXCEPT(trie.num_keys(), MARISA_STATE_ERROR);
  EXCEPT(trie.num_nodes(), MARISA_STATE_ERROR);

  EXCEPT(trie.tail_mode(), MARISA_STATE_ERROR);
  EXCEPT(trie.node_order(), MARISA_STATE_ERROR);

  EXCEPT(trie.empty(), MARISA_STATE_ERROR);
  EXCEPT(trie.size(), MARISA_STATE_ERROR);
  EXCEPT(trie.total_size(), MARISA_STATE_ERROR);
  EXCEPT(trie.io_size(), MARISA_STATE_ERROR);

  // Building from an empty keyset is legal and produces a usable trie.
  marisa::Keyset keyset;
  trie.build(keyset);

  // Searches find nothing. reverse_lookup has no valid ID to resolve, so it
  // reports MARISA_BOUND_ERROR instead.
  ASSERT(!trie.lookup(agent));
  EXCEPT(trie.reverse_lookup(agent), MARISA_BOUND_ERROR);
  ASSERT(!trie.common_prefix_search(agent));
  ASSERT(!trie.predictive_search(agent));

  // One trie containing a single node and no keys.
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 0);
  ASSERT(trie.num_nodes() == 1);

  // build() without flags uses the default tail mode and node order.
  ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
  ASSERT(trie.node_order() == MARISA_DEFAULT_ORDER);

  // No keys, yet the structure still occupies memory and serialized space.
  ASSERT(trie.empty());
  ASSERT(trie.size() == 0);
  ASSERT(trie.total_size() != 0);
  ASSERT(trie.io_size() != 0);

  // The empty string is a legal key. Adding it does not add nodes (see the
  // num_nodes() check below).
  keyset.push_back("");
  trie.build(keyset);

  // The agent's query was never set, yet lookup succeeds. NOTE(review): this
  // presumably means a default-constructed agent queries "" with ID 0.
  ASSERT(trie.lookup(agent));
  // The result of this reverse_lookup is not checked here.
  trie.reverse_lookup(agent);
  // "" is the only stored key. Each search yields exactly one hit, then
  // reports exhaustion on the next call.
  ASSERT(trie.common_prefix_search(agent));
  ASSERT(!trie.common_prefix_search(agent));
  ASSERT(trie.predictive_search(agent));
  ASSERT(!trie.predictive_search(agent));

  ASSERT(trie.num_keys() == 1);
  ASSERT(trie.num_nodes() == 1);

  ASSERT(!trie.empty());
  ASSERT(trie.size() == 1);
  ASSERT(trie.total_size() != 0);
  ASSERT(trie.io_size() != 0);

  TEST_END();
}
85 
// Builds a small keyset (with "check" listed twice) into a single trie. It
// builds first with the default node order, then with MARISA_LABEL_ORDER.
// It checks the assigned key IDs, lookup/reverse_lookup round trips, and the
// exact results of common_prefix_search and predictive_search.
void TestTinyTrie() {
  TEST_START();

  marisa::Keyset keyset;
  keyset.push_back("bach");
  keyset.push_back("bet");
  keyset.push_back("chat");
  keyset.push_back("check");
  keyset.push_back("check");  // Duplicate entry: must map to the same ID.

  marisa::Trie trie;
  // The flag value 1 requests one trie, as num_tries() confirms below.
  trie.build(keyset, 1);

  // Five entries, but the duplicate collapses, leaving 4 distinct keys.
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 7);

  ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
  ASSERT(trie.node_order() == MARISA_DEFAULT_ORDER);

  // build() writes each entry's assigned ID back into the keyset. Under the
  // default order, "check" (entered twice) receives ID 0.
  // NOTE(review): presumably MARISA_DEFAULT_ORDER is weight order and the
  // duplicate's weights add up; confirm in the marisa sources.
  ASSERT(keyset[0].id() == 2);
  ASSERT(keyset[1].id() == 3);
  ASSERT(keyset[2].id() == 1);
  ASSERT(keyset[3].id() == 0);
  ASSERT(keyset[4].id() == 0);

  // Every entry must round-trip. lookup maps the string to its ID, and
  // reverse_lookup maps the ID back to identical bytes.
  marisa::Agent agent;
  for (std::size_t i = 0; i < keyset.size(); ++i) {
    agent.set_query(keyset[i].ptr(), keyset[i].length());
    ASSERT(trie.lookup(agent));
    ASSERT(agent.key().id() == keyset[i].id());

    agent.set_query(keyset[i].id());
    trie.reverse_lookup(agent);
    ASSERT(agent.key().length() == keyset[i].length());
    ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
        agent.key().length()) == 0);
  }

  // common_prefix_search reports stored keys that are prefixes of the query.
  // "be" and "beX" have none; "bet" and "betX" each yield only "bet".
  agent.set_query("be");
  ASSERT(!trie.common_prefix_search(agent));
  agent.set_query("beX");
  ASSERT(!trie.common_prefix_search(agent));
  agent.set_query("bet");
  ASSERT(trie.common_prefix_search(agent));
  ASSERT(!trie.common_prefix_search(agent));
  agent.set_query("betX");
  ASSERT(trie.common_prefix_search(agent));
  ASSERT(!trie.common_prefix_search(agent));

  // predictive_search reports stored keys that start with the query.
  // "chatX" extends past every key, so nothing matches.
  agent.set_query("chatX");
  ASSERT(!trie.predictive_search(agent));
  agent.set_query("chat");
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 4);
  ASSERT(!trie.predictive_search(agent));

  agent.set_query("cha");
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 4);
  ASSERT(!trie.predictive_search(agent));

  // "c" and "ch" both match "check" and "chat". Under the default order,
  // "check" is returned first.
  agent.set_query("c");
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 5);
  ASSERT(std::memcmp(agent.key().ptr(), "check", 5) == 0);
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 4);
  ASSERT(std::memcmp(agent.key().ptr(), "chat", 4) == 0);
  ASSERT(!trie.predictive_search(agent));

  agent.set_query("ch");
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 5);
  ASSERT(std::memcmp(agent.key().ptr(), "check", 5) == 0);
  ASSERT(trie.predictive_search(agent));
  ASSERT(agent.key().length() == 4);
  ASSERT(std::memcmp(agent.key().ptr(), "chat", 4) == 0);
  ASSERT(!trie.predictive_search(agent));

  // Rebuild the same keyset with label order. The shape is unchanged, but
  // IDs now follow the keys' alphabetical order, so both "check" entries get
  // the last ID, 3.
  trie.build(keyset, 1 | MARISA_LABEL_ORDER);

  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 4);
  ASSERT(trie.num_nodes() == 7);

  ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
  ASSERT(trie.node_order() == MARISA_LABEL_ORDER);

  ASSERT(keyset[0].id() == 0);
  ASSERT(keyset[1].id() == 1);
  ASSERT(keyset[2].id() == 2);
  ASSERT(keyset[3].id() == 3);
  ASSERT(keyset[4].id() == 3);

  for (std::size_t i = 0; i < keyset.size(); ++i) {
    agent.set_query(keyset[i].ptr(), keyset[i].length());
    ASSERT(trie.lookup(agent));
    ASSERT(agent.key().id() == keyset[i].id());

    agent.set_query(keyset[i].id());
    trie.reverse_lookup(agent);
    ASSERT(agent.key().length() == keyset[i].length());
    ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
        agent.key().length()) == 0);
  }

  // The query "" is a prefix of every key. Under label order, predictive
  // search must visit all trie.size() keys in ascending ID order, then stop.
  agent.set_query("");
  for (std::size_t i = 0; i < trie.size(); ++i) {
    ASSERT(trie.predictive_search(agent));
    ASSERT(agent.key().id() == i);
  }
  ASSERT(!trie.predictive_search(agent));

  TEST_END();
}
202 
MakeKeyset(std::size_t num_keys,marisa::TailMode tail_mode,marisa::Keyset * keyset)203 void MakeKeyset(std::size_t num_keys, marisa::TailMode tail_mode,
204     marisa::Keyset *keyset) {
205   char key_buf[16];
206   for (std::size_t i = 0; i < num_keys; ++i) {
207     const std::size_t length =
208         static_cast<std::size_t>(std::rand()) % sizeof(key_buf);
209     for (std::size_t j = 0; j < length; ++j) {
210       key_buf[j] = (char)(std::rand() % 10);
211       if (tail_mode == MARISA_TEXT_TAIL) {
212         key_buf[j] = static_cast<char>(key_buf[j] + '0');
213       }
214     }
215     keyset->push_back(key_buf, length);
216   }
217 }
218 
TestLookup(const marisa::Trie & trie,const marisa::Keyset & keyset)219 void TestLookup(const marisa::Trie &trie, const marisa::Keyset &keyset) {
220   marisa::Agent agent;
221   for (std::size_t i = 0; i < keyset.size(); ++i) {
222     agent.set_query(keyset[i].ptr(), keyset[i].length());
223     ASSERT(trie.lookup(agent));
224     ASSERT(agent.key().id() == keyset[i].id());
225 
226     agent.set_query(keyset[i].id());
227     trie.reverse_lookup(agent);
228     ASSERT(agent.key().length() == keyset[i].length());
229     ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
230         agent.key().length()) == 0);
231   }
232 }
233 
TestCommonPrefixSearch(const marisa::Trie & trie,const marisa::Keyset & keyset)234 void TestCommonPrefixSearch(const marisa::Trie &trie,
235     const marisa::Keyset &keyset) {
236   marisa::Agent agent;
237   for (std::size_t i = 0; i < keyset.size(); ++i) {
238     agent.set_query(keyset[i].ptr(), keyset[i].length());
239     ASSERT(trie.common_prefix_search(agent));
240     ASSERT(agent.key().id() <= keyset[i].id());
241     while (trie.common_prefix_search(agent)) {
242       ASSERT(agent.key().id() <= keyset[i].id());
243     }
244     ASSERT(agent.key().id() == keyset[i].id());
245   }
246 }
247 
TestPredictiveSearch(const marisa::Trie & trie,const marisa::Keyset & keyset)248 void TestPredictiveSearch(const marisa::Trie &trie,
249     const marisa::Keyset &keyset) {
250   marisa::Agent agent;
251   for (std::size_t i = 0; i < keyset.size(); ++i) {
252     agent.set_query(keyset[i].ptr(), keyset[i].length());
253     ASSERT(trie.predictive_search(agent));
254     ASSERT(agent.key().id() == keyset[i].id());
255     while (trie.predictive_search(agent)) {
256       ASSERT(agent.key().id() > keyset[i].id());
257     }
258   }
259 }
260 
TestTrie(int num_tries,marisa::TailMode tail_mode,marisa::NodeOrder node_order,marisa::Keyset & keyset)261 void TestTrie(int num_tries, marisa::TailMode tail_mode,
262     marisa::NodeOrder node_order, marisa::Keyset &keyset) {
263   for (std::size_t i = 0; i < keyset.size(); ++i) {
264     keyset[i].set_weight(1.0F);
265   }
266 
267   marisa::Trie trie;
268   trie.build(keyset, num_tries | tail_mode | node_order);
269 
270   ASSERT(trie.num_tries() == (std::size_t)num_tries);
271   ASSERT(trie.num_keys() <= keyset.size());
272 
273   ASSERT(trie.tail_mode() == tail_mode);
274   ASSERT(trie.node_order() == node_order);
275 
276   TestLookup(trie, keyset);
277   TestCommonPrefixSearch(trie, keyset);
278   TestPredictiveSearch(trie, keyset);
279 
280   trie.save("marisa-test.dat");
281 
282   trie.clear();
283   trie.load("marisa-test.dat");
284 
285   ASSERT(trie.num_tries() == (std::size_t)num_tries);
286   ASSERT(trie.num_keys() <= keyset.size());
287 
288   ASSERT(trie.tail_mode() == tail_mode);
289   ASSERT(trie.node_order() == node_order);
290 
291   TestLookup(trie, keyset);
292 
293   {
294     std::FILE *file;
295 #ifdef _MSC_VER
296     ASSERT(::fopen_s(&file, "marisa-test.dat", "wb") == 0);
297 #else  // _MSC_VER
298     file = std::fopen("marisa-test.dat", "wb");
299     ASSERT(file != NULL);
300 #endif  // _MSC_VER
301     marisa::fwrite(file, trie);
302     std::fclose(file);
303     trie.clear();
304 #ifdef _MSC_VER
305     ASSERT(::fopen_s(&file, "marisa-test.dat", "rb") == 0);
306 #else  // _MSC_VER
307     file = std::fopen("marisa-test.dat", "rb");
308     ASSERT(file != NULL);
309 #endif  // _MSC_VER
310     marisa::fread(file, &trie);
311     std::fclose(file);
312   }
313 
314   ASSERT(trie.num_tries() == (std::size_t)num_tries);
315   ASSERT(trie.num_keys() <= keyset.size());
316 
317   ASSERT(trie.tail_mode() == tail_mode);
318   ASSERT(trie.node_order() == node_order);
319 
320   TestLookup(trie, keyset);
321 
322   trie.clear();
323   trie.mmap("marisa-test.dat");
324 
325   ASSERT(trie.num_tries() == (std::size_t)num_tries);
326   ASSERT(trie.num_keys() <= keyset.size());
327 
328   ASSERT(trie.tail_mode() == tail_mode);
329   ASSERT(trie.node_order() == node_order);
330 
331   TestLookup(trie, keyset);
332 
333   {
334     std::stringstream stream;
335     stream << trie;
336     trie.clear();
337     stream >> trie;
338   }
339 
340   ASSERT(trie.num_tries() == (std::size_t)num_tries);
341   ASSERT(trie.num_keys() <= keyset.size());
342 
343   ASSERT(trie.tail_mode() == tail_mode);
344   ASSERT(trie.node_order() == node_order);
345 
346   TestLookup(trie, keyset);
347 }
348 
TestTrie(marisa::TailMode tail_mode,marisa::NodeOrder node_order,marisa::Keyset & keyset)349 void TestTrie(marisa::TailMode tail_mode, marisa::NodeOrder node_order,
350     marisa::Keyset &keyset) {
351   TEST_START();
352   std::cout << ((tail_mode == MARISA_TEXT_TAIL) ? "TEXT" : "BINARY") << ", ";
353   std::cout << ((node_order == MARISA_WEIGHT_ORDER) ?
354       "WEIGHT" : "LABEL") << ": ";
355 
356   for (int i = 1; i < 5; ++i) {
357     TestTrie(i, tail_mode, node_order, keyset);
358   }
359 
360   TEST_END();
361 }
362 
TestTrie(marisa::TailMode tail_mode)363 void TestTrie(marisa::TailMode tail_mode) {
364   marisa::Keyset keyset;
365   MakeKeyset(1000, tail_mode, &keyset);
366 
367   TestTrie(tail_mode, MARISA_WEIGHT_ORDER, keyset);
368   TestTrie(tail_mode, MARISA_LABEL_ORDER, keyset);
369 }
370 
TestTrie()371 void TestTrie() {
372   TestTrie(MARISA_TEXT_TAIL);
373   TestTrie(MARISA_BINARY_TAIL);
374 }
375 
376 }  // namespace
377 
main()378 int main() try {
379   std::srand((unsigned int)std::time(NULL));
380 
381   TestEmptyTrie();
382   TestTinyTrie();
383   TestTrie();
384 
385   return 0;
386 } catch (const marisa::Exception &ex) {
387   std::cerr << ex.what() << std::endl;
388   throw;
389 }
390