#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <marisa.h>

#include "cmdopt.h"
12 
13 namespace {
14 
15 int param_min_num_tries = 1;
16 int param_max_num_tries = 5;
17 marisa::TailMode param_tail_mode = MARISA_DEFAULT_TAIL;
18 marisa::NodeOrder param_node_order = MARISA_DEFAULT_ORDER;
19 marisa::CacheLevel param_cache_level = MARISA_DEFAULT_CACHE;
20 bool param_with_predict = true;
21 bool param_print_speed = true;
22 
// Stopwatch over processor time as reported by std::clock(). It starts at
// construction and can be restarted with reset().
class Clock {
 public:
  Clock() : start_(std::clock()) {}

  // Moves the start point to the current processor time.
  void reset() {
    start_ = std::clock();
  }

  // Seconds of processor time consumed since the start point.
  // (The misspelled name is part of the interface used by callers.)
  double elasped() const {
    const double ticks = static_cast<double>(std::clock() - start_);
    return ticks / static_cast<double>(CLOCKS_PER_SEC);
  }

 private:
  std::clock_t start_;
};
39 
print_help(const char * cmd)40 void print_help(const char *cmd) {
41   std::cerr << "Usage: " << cmd << " [OPTION]... [FILE]...\n\n"
42       "Options:\n"
43       "  -N, --min-num-tries=[N]  limit the number of tries"
44       " [" << MARISA_MIN_NUM_TRIES << ", " << MARISA_MAX_NUM_TRIES
45       << "] (default: 1)\n"
46       "  -n, --max-num-tries=[N]  limit the number of tries"
47       " [" << MARISA_MIN_NUM_TRIES << ", " << MARISA_MAX_NUM_TRIES
48       << "] (default: 5)\n"
49       "  -t, --text-tail      build a dictionary with text TAIL (default)\n"
50       "  -b, --binary-tail    build a dictionary with binary TAIL\n"
51       "  -w, --weight-order   arrange siblings in weight order (default)\n"
52       "  -l, --label-order    arrange siblings in label order\n"
53       "  -c, --cache-level=[N]    specify the cache size"
54       " [1, 5] (default: 3)\n"
55       "  -P, --with-predict       include predictive search (default)\n"
56       "  -p, --without-predict    skip predictive search\n"
57       "  -S, --print-speed    print speed [1000 keys/s] (default)\n"
58       "  -s, --print-time     print time [ns/key]\n"
59       "  -h, --help           print this help\n"
60       << std::endl;
61 }
62 
print_config()63 void print_config() {
64   std::cout << "Number of tries: " << param_min_num_tries
65       << " - " << param_max_num_tries << std::endl;
66 
67   std::cout << "TAIL mode: ";
68   switch (param_tail_mode) {
69     case MARISA_TEXT_TAIL: {
70       std::cout << "Text mode" << std::endl;
71       break;
72     }
73     case MARISA_BINARY_TAIL: {
74       std::cout << "Binary mode" << std::endl;
75       break;
76     }
77   }
78 
79   std::cout << "Node order: ";
80   switch (param_node_order) {
81     case MARISA_LABEL_ORDER: {
82       std::cout << "Ascending label order" << std::endl;
83       break;
84     }
85     case MARISA_WEIGHT_ORDER: {
86       std::cout << "Descending weight order" << std::endl;
87       break;
88     }
89   }
90 
91   std::cout << "Cache level: ";
92   switch (param_cache_level) {
93     case MARISA_HUGE_CACHE: {
94       std::cout << "Huge cache" << std::endl;
95       break;
96     }
97     case MARISA_LARGE_CACHE: {
98       std::cout << "Large cache" << std::endl;
99       break;
100     }
101     case MARISA_NORMAL_CACHE: {
102       std::cout << "Normal cache" << std::endl;
103       break;
104     }
105     case MARISA_SMALL_CACHE: {
106       std::cout << "Small cache" << std::endl;
107       break;
108     }
109     case MARISA_TINY_CACHE: {
110       std::cout << "Tiny cache" << std::endl;
111       break;
112     }
113   }
114 }
115 
print_time_info(std::size_t num_keys,double elasped)116 void print_time_info(std::size_t num_keys, double elasped) {
117   if (param_print_speed) {
118     if (elasped == 0.0) {
119       std::printf(" %8s", "-");
120     } else {
121       std::printf(" %8.2f", static_cast<double>(num_keys) / elasped / 1000.0);
122     }
123   } else {
124     if ((elasped == 0.0) || (num_keys == 0)) {
125       std::printf(" %8s", "-");
126     } else {
127       std::printf(" %8.1f",
128           1000000000.0 * elasped / static_cast<double>(num_keys));
129     }
130   }
131 }
132 
read_keys(std::istream & input,marisa::Keyset * keyset,std::vector<float> * weights)133 void read_keys(std::istream &input, marisa::Keyset *keyset,
134     std::vector<float> *weights) {
135   std::string line;
136   while (std::getline(input, line)) {
137     const std::string::size_type delim_pos = line.find_last_of('\t');
138     float weight = 1.0F;
139     if (delim_pos != line.npos) {
140       char *end_of_value;
141       weight = (float)std::strtod(&line[delim_pos + 1], &end_of_value);
142       if (*end_of_value == '\0') {
143         line.resize(delim_pos);
144       }
145     }
146     keyset->push_back(line.c_str(), line.length());
147     weights->push_back(weight);
148   }
149 }
150 
read_keys(const char * const * args,std::size_t num_args,marisa::Keyset * keyset,std::vector<float> * weights)151 int read_keys(const char * const *args, std::size_t num_args,
152     marisa::Keyset *keyset, std::vector<float> *weights) {
153   if (num_args == 0) {
154     read_keys(std::cin, keyset, weights);
155   }
156   for (std::size_t i = 0; i < num_args; ++i) {
157     std::ifstream input_file(args[i], std::ios::binary);
158     if (!input_file) {
159       std::cerr << "error: failed to open: " << args[i] << std::endl;
160       return 10;
161     }
162     read_keys(input_file, keyset, weights);
163   }
164   std::cout << "Number of keys: " << keyset->size() << std::endl;
165   std::cout << "Total length: " << keyset->total_length() << std::endl;
166   return 0;
167 }
168 
benchmark_build(marisa::Keyset & keyset,const std::vector<float> & weights,int num_tries,marisa::Trie * trie)169 void benchmark_build(marisa::Keyset &keyset,
170     const std::vector<float> &weights, int num_tries, marisa::Trie *trie) {
171   for (std::size_t i = 0; i < keyset.size(); ++i) {
172     keyset[i].set_weight(weights[i]);
173   }
174   Clock cl;
175   trie->build(keyset, num_tries | param_tail_mode | param_node_order |
176       param_cache_level);
177   std::printf(" %10lu", (unsigned long)trie->io_size());
178   print_time_info(keyset.size(), cl.elasped());
179 }
180 
benchmark_lookup(const marisa::Trie & trie,const marisa::Keyset & keyset)181 void benchmark_lookup(const marisa::Trie &trie,
182     const marisa::Keyset &keyset) {
183   Clock cl;
184   marisa::Agent agent;
185   for (std::size_t i = 0; i < keyset.size(); ++i) {
186     agent.set_query(keyset[i].ptr(), keyset[i].length());
187     if (!trie.lookup(agent) || (agent.key().id() != keyset[i].id())) {
188       std::cerr << "error: lookup() failed" << std::endl;
189       return;
190     }
191   }
192   print_time_info(keyset.size(), cl.elasped());
193 }
194 
benchmark_reverse_lookup(const marisa::Trie & trie,const marisa::Keyset & keyset)195 void benchmark_reverse_lookup(const marisa::Trie &trie,
196     const marisa::Keyset &keyset) {
197   Clock cl;
198   marisa::Agent agent;
199   for (std::size_t i = 0; i < keyset.size(); ++i) {
200     agent.set_query(keyset[i].id());
201     trie.reverse_lookup(agent);
202     if ((agent.key().id() != keyset[i].id()) ||
203         (agent.key().length() != keyset[i].length()) ||
204         (std::memcmp(agent.key().ptr(), keyset[i].ptr(),
205             agent.key().length()) != 0)) {
206       std::cerr << "error: reverse_lookup() failed" << std::endl;
207       return;
208     }
209   }
210   print_time_info(keyset.size(), cl.elasped());
211 }
212 
benchmark_common_prefix_search(const marisa::Trie & trie,const marisa::Keyset & keyset)213 void benchmark_common_prefix_search(const marisa::Trie &trie,
214     const marisa::Keyset &keyset) {
215   Clock cl;
216   marisa::Agent agent;
217   for (std::size_t i = 0; i < keyset.size(); ++i) {
218     agent.set_query(keyset[i].ptr(), keyset[i].length());
219     while (trie.common_prefix_search(agent)) {
220       if (agent.key().id() > keyset[i].id()) {
221         std::cerr << "error: common_prefix_search() failed" << std::endl;
222         return;
223       }
224     }
225     if (agent.key().id() != keyset[i].id()) {
226       std::cerr << "error: common_prefix_search() failed" << std::endl;
227       return;
228     }
229   }
230   print_time_info(keyset.size(), cl.elasped());
231 }
232 
benchmark_predictive_search(const marisa::Trie & trie,const marisa::Keyset & keyset)233 void benchmark_predictive_search(const marisa::Trie &trie,
234     const marisa::Keyset &keyset) {
235   if (!param_with_predict) {
236     print_time_info(keyset.size(), 0.0);
237     return;
238   }
239 
240   Clock cl;
241   marisa::Agent agent;
242   for (std::size_t i = 0; i < keyset.size(); ++i) {
243     agent.set_query(keyset[i].ptr(), keyset[i].length());
244     if (!trie.predictive_search(agent) ||
245         (agent.key().id() != keyset[i].id())) {
246       std::cerr << "error: predictive_search() failed" << std::endl;
247       return;
248     }
249     while (trie.predictive_search(agent)) {
250       if (agent.key().id() <= keyset[i].id()) {
251         std::cerr << "error: predictive_search() failed" << std::endl;
252         return;
253       }
254     }
255   }
256   print_time_info(keyset.size(), cl.elasped());
257 }
258 
benchmark(marisa::Keyset & keyset,const std::vector<float> & weights,int num_tries)259 void benchmark(marisa::Keyset &keyset, const std::vector<float> &weights,
260     int num_tries) {
261   std::printf("%6d", num_tries);
262   marisa::Trie trie;
263   benchmark_build(keyset, weights, num_tries, &trie);
264   if (!trie.empty()) {
265     benchmark_lookup(trie, keyset);
266     benchmark_reverse_lookup(trie, keyset);
267     benchmark_common_prefix_search(trie, keyset);
268     benchmark_predictive_search(trie, keyset);
269   }
270   std::printf("\n");
271 }
272 
benchmark(const char * const * args,std::size_t num_args)273 int benchmark(const char * const *args, std::size_t num_args) try {
274   marisa::Keyset keyset;
275   std::vector<float> weights;
276   const int ret = read_keys(args, num_args, &keyset, &weights);
277   if (ret != 0) {
278     return ret;
279   }
280   std::printf("------+----------+--------+--------+"
281       "--------+--------+--------\n");
282   std::printf("%6s %10s %8s %8s %8s %8s %8s\n",
283       "#tries", "size", "build", "lookup", "reverse", "prefix", "predict");
284   std::printf("%6s %10s %8s %8s %8s %8s %8s\n",
285       "", "", "", "", "lookup", "search", "search");
286   if (param_print_speed) {
287     std::printf("%6s %10s %8s %8s %8s %8s %8s\n",
288         "", "[bytes]",
289         "[K/s]", "[K/s]", "[K/s]", "[K/s]", "[K/s]");
290   } else {
291     std::printf("%6s %10s %8s %8s %8s %8s %8s\n",
292         "", "[bytes]", "[ns]", "[ns]", "[ns]", "[ns]", "[ns]");
293   }
294   std::printf("------+----------+--------+--------+"
295       "--------+--------+--------\n");
296   for (int i = param_min_num_tries; i <= param_max_num_tries; ++i) {
297     benchmark(keyset, weights, i);
298   }
299   std::printf("------+----------+--------+--------+"
300       "--------+--------+--------\n");
301   return 0;
302 } catch (const marisa::Exception &ex) {
303   std::cerr << ex.what() << std::endl;
304   return -1;
305 }
306 
307 }  // namespace
308 
main(int argc,char * argv[])309 int main(int argc, char *argv[]) {
310   std::ios::sync_with_stdio(false);
311 
312   ::cmdopt_option long_options[] = {
313     { "min-num-tries", 1, NULL, 'N' },
314     { "max-num-tries", 1, NULL, 'n' },
315     { "text-tail", 0, NULL, 't' },
316     { "binary-tail", 0, NULL, 'b' },
317     { "weight-order", 0, NULL, 'w' },
318     { "label-order", 0, NULL, 'l' },
319     { "cache-level", 1, NULL, 'c' },
320     { "predict-on", 0, NULL, 'P' },
321     { "predict-off", 0, NULL, 'p' },
322     { "print-speed", 0, NULL, 'S' },
323     { "print-time", 0, NULL, 's' },
324     { "help", 0, NULL, 'h' },
325     { NULL, 0, NULL, 0 }
326   };
327   ::cmdopt_t cmdopt;
328   ::cmdopt_init(&cmdopt, argc, argv, "N:n:tbwlc:PpSsh", long_options);
329   int label;
330   while ((label = ::cmdopt_get(&cmdopt)) != -1) {
331     switch (label) {
332       case 'N': {
333         char *end_of_value;
334         const long value = std::strtol(cmdopt.optarg, &end_of_value, 10);
335         if ((*end_of_value != '\0') || (value <= 0) ||
336             (value > MARISA_MAX_NUM_TRIES)) {
337           std::cerr << "error: option `-n' with an invalid argument: "
338               << cmdopt.optarg << std::endl;
339           return 1;
340         }
341         param_min_num_tries = (int)value;
342         break;
343       }
344       case 'n': {
345         char *end_of_value;
346         const long value = std::strtol(cmdopt.optarg, &end_of_value, 10);
347         if ((*end_of_value != '\0') || (value <= 0) ||
348             (value > MARISA_MAX_NUM_TRIES)) {
349           std::cerr << "error: option `-n' with an invalid argument: "
350               << cmdopt.optarg << std::endl;
351           return 2;
352         }
353         param_max_num_tries = (int)value;
354         break;
355       }
356       case 't': {
357         param_tail_mode = MARISA_TEXT_TAIL;
358         break;
359       }
360       case 'b': {
361         param_tail_mode = MARISA_BINARY_TAIL;
362         break;
363       }
364       case 'w': {
365         param_node_order = MARISA_WEIGHT_ORDER;
366         break;
367       }
368       case 'l': {
369         param_node_order = MARISA_LABEL_ORDER;
370         break;
371       }
372       case 'c': {
373         char *end_of_value;
374         const long value = std::strtol(cmdopt.optarg, &end_of_value, 10);
375         if ((*end_of_value != '\0') || (value < 1) || (value > 5)) {
376           std::cerr << "error: option `-c' with an invalid argument: "
377               << cmdopt.optarg << std::endl;
378           return 3;
379         } else if (value == 1) {
380           param_cache_level = MARISA_TINY_CACHE;
381         } else if (value == 2) {
382           param_cache_level = MARISA_SMALL_CACHE;
383         } else if (value == 3) {
384           param_cache_level = MARISA_NORMAL_CACHE;
385         } else if (value == 4) {
386           param_cache_level = MARISA_LARGE_CACHE;
387         } else if (value == 5) {
388           param_cache_level = MARISA_HUGE_CACHE;
389         }
390         break;
391       }
392       case 'P': {
393         param_with_predict = true;
394         break;
395       }
396       case 'p': {
397         param_with_predict = false;
398         break;
399       }
400       case 'S': {
401         param_print_speed = true;
402         break;
403       }
404       case 's': {
405         param_print_speed = false;
406         break;
407       }
408       case 'h': {
409         print_help(argv[0]);
410         return 0;
411       }
412       default: {
413         return 1;
414       }
415     }
416   }
417   print_config();
418   return benchmark(cmdopt.argv + cmdopt.optind,
419       static_cast<std::size_t>(cmdopt.argc - cmdopt.optind));
420 }
421