1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 
13 #include "tensorflow/lite/tools/command_line_flags.h"
14 
15 #include <algorithm>
16 #include <cstring>
17 #include <iomanip>
18 #include <numeric>
19 #include <sstream>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/lite/tools/logging.h"
26 
27 namespace tflite {
28 namespace {
29 
// Formats `val` as text by streaming it through operator<< into a string
// stream, and returns the accumulated characters.
template <typename T>
std::string ToString(T val) {
  std::ostringstream formatter;
  formatter << val;
  return formatter.str();
}
36 
// Tries to match the command-line argument `arg` against the flag `flag`.
//
// Positional flags: `arg` is handed to `parse_func` as-is; the function always
// returns true and `*value_parsing_ok` carries the result of `parse_func`.
//
// Named flags: `arg` must start with "--<flag>=". If it does not, the function
// returns false and `*value_parsing_ok` is set to true (nothing was parsed, so
// nothing failed). If it does, everything after the '=' (possibly an empty
// string) is handed to `parse_func`, `*value_parsing_ok` carries its result,
// and the function returns true.
bool ParseFlag(const std::string& arg, const std::string& flag, bool positional,
               const std::function<bool(const std::string&)>& parse_func,
               bool* value_parsing_ok) {
  if (positional) {
    *value_parsing_ok = parse_func(arg);
    return true;
  }
  *value_parsing_ok = true;
  const std::string flag_prefix = "--" + flag + "=";
  // Anchored prefix comparison: unlike find(), this never scans past the
  // prefix length when `arg` does not start with the prefix.
  if (arg.compare(0, flag_prefix.size(), flag_prefix) != 0) {
    return false;
  }
  // `arg` is at least as long as the prefix here, so substr() is always valid;
  // an argument of exactly "--<flag>=" yields an empty value string.
  *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
  return true;
}
56 
// Parses `flag_value` as a `T` using stream extraction and, on success, passes
// the parsed value to `hook`.
//
// Returns false without calling `hook` when the extraction fails. That
// includes failures that coincide with reaching the end of the input: an empty
// or whitespace-only value, a lone sign character, or a number that is out of
// range for `T`. Returns true after calling `hook` otherwise.
//
// Note: text left over after a successfully extracted value (e.g. "12abc") is
// not treated as an error.
template <typename T>
bool ParseFlag(const std::string& flag_value,
               const std::function<void(const T&)>& hook) {
  std::istringstream stream(flag_value);
  T read_value;
  stream >> read_value;
  // fail() covers both failbit and badbit. Checking it regardless of eof()
  // matters because a failed extraction that consumed the whole input sets
  // eofbit as well, and must still be rejected.
  if (stream.fail()) {
    return false;
  }
  hook(read_value);
  return true;
}
69 
// Parses a boolean flag value. Only the exact spellings "true"/"1" (true) and
// "false"/"0" (false) are accepted; on a match `hook` is called with the
// corresponding value and true is returned. Any other text returns false
// without calling `hook`.
bool ParseBoolFlag(const std::string& flag_value,
                   const std::function<void(const bool&)>& hook) {
  const bool is_true = (flag_value == "true" || flag_value == "1");
  const bool is_false = (flag_value == "false" || flag_value == "0");
  if (!is_true && !is_false) {
    return false;
  }

  hook(is_true);
  return true;
}
80 }  // namespace
81 
Flag(const char * name,const std::function<void (const int32_t &)> & hook,int32_t default_value,const std::string & usage_text,FlagType flag_type)82 Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
83            int32_t default_value, const std::string& usage_text,
84            FlagType flag_type)
85     : name_(name),
86       type_(TYPE_INT32),
87       value_hook_([hook](const std::string& flag_value) {
88         return ParseFlag<int32_t>(flag_value, hook);
89       }),
90       default_for_display_(ToString(default_value)),
91       usage_text_(usage_text),
92       flag_type_(flag_type) {}
93 
Flag(const char * name,const std::function<void (const int64_t &)> & hook,int64_t default_value,const std::string & usage_text,FlagType flag_type)94 Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
95            int64_t default_value, const std::string& usage_text,
96            FlagType flag_type)
97     : name_(name),
98       type_(TYPE_INT64),
99       value_hook_([hook](const std::string& flag_value) {
100         return ParseFlag<int64_t>(flag_value, hook);
101       }),
102       default_for_display_(ToString(default_value)),
103       usage_text_(usage_text),
104       flag_type_(flag_type) {}
105 
Flag(const char * name,const std::function<void (const float &)> & hook,float default_value,const std::string & usage_text,FlagType flag_type)106 Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
107            float default_value, const std::string& usage_text,
108            FlagType flag_type)
109     : name_(name),
110       type_(TYPE_FLOAT),
111       value_hook_([hook](const std::string& flag_value) {
112         return ParseFlag<float>(flag_value, hook);
113       }),
114       default_for_display_(ToString(default_value)),
115       usage_text_(usage_text),
116       flag_type_(flag_type) {}
117 
Flag(const char * name,const std::function<void (const bool &)> & hook,bool default_value,const std::string & usage_text,FlagType flag_type)118 Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
119            bool default_value, const std::string& usage_text,
120            FlagType flag_type)
121     : name_(name),
122       type_(TYPE_BOOL),
123       value_hook_([hook](const std::string& flag_value) {
124         return ParseBoolFlag(flag_value, hook);
125       }),
126       default_for_display_(default_value ? "true" : "false"),
127       usage_text_(usage_text),
128       flag_type_(flag_type) {}
129 
Flag(const char * name,const std::function<void (const std::string &)> & hook,const std::string & default_value,const std::string & usage_text,FlagType flag_type)130 Flag::Flag(const char* name,
131            const std::function<void(const std::string&)>& hook,
132            const std::string& default_value, const std::string& usage_text,
133            FlagType flag_type)
134     : name_(name),
135       type_(TYPE_STRING),
136       value_hook_([hook](const std::string& flag_value) {
137         hook(flag_value);
138         return true;
139       }),
140       default_for_display_(default_value),
141       usage_text_(usage_text),
142       flag_type_(flag_type) {}
143 
Parse(const std::string & arg,bool * value_parsing_ok) const144 bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
145   return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
146                    value_parsing_ok);
147 }
148 
GetTypeName() const149 std::string Flag::GetTypeName() const {
150   switch (type_) {
151     case TYPE_INT32:
152       return "int32";
153     case TYPE_INT64:
154       return "int64";
155     case TYPE_FLOAT:
156       return "float";
157     case TYPE_BOOL:
158       return "bool";
159     case TYPE_STRING:
160       return "string";
161   }
162 
163   return "unknown";
164 }
165 
Parse(int * argc,const char ** argv,const std::vector<Flag> & flag_list)166 /*static*/ bool Flags::Parse(int* argc, const char** argv,
167                              const std::vector<Flag>& flag_list) {
168   bool result = true;
169   std::vector<bool> unknown_argvs(*argc, true);
170   // Record the list of flags that have been processed. key is the flag's name
171   // and the value is the corresponding argv index if there's one, or -1 when
172   // the argv list doesn't contain this flag.
173   std::unordered_map<std::string, int> processed_flags;
174 
175   // Stores indexes of flag_list in a sorted order.
176   std::vector<int> sorted_idx(flag_list.size());
177   std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
178   std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
179     return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
180   });
181   int positional_count = 0;
182 
183   for (int idx = 0; idx < sorted_idx.size(); ++idx) {
184     const Flag& flag = flag_list[sorted_idx[idx]];
185 
186     const auto it = processed_flags.find(flag.name_);
187     if (it != processed_flags.end()) {
188 #ifndef NDEBUG
189       // Only log this in debug builds.
190       TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
191 #endif
192       if (it->second != -1) {
193         bool value_parsing_ok;
194         flag.Parse(argv[it->second], &value_parsing_ok);
195         if (!value_parsing_ok) {
196           TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
197                             << "' against argv '" << argv[it->second] << "'";
198           result = false;
199         }
200         continue;
201       } else if (flag.flag_type_ == Flag::kRequired) {
202         TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
203         // If the required flag isn't found, we immediately stop the whole flag
204         // parsing.
205         result = false;
206         break;
207       }
208     }
209 
210     // Parses positional flags.
211     if (flag.flag_type_ == Flag::kPositional) {
212       if (++positional_count >= *argc) {
213         TFLITE_LOG(ERROR) << "Too few command line arguments.";
214         return false;
215       }
216       bool value_parsing_ok;
217       flag.Parse(argv[positional_count], &value_parsing_ok);
218       if (!value_parsing_ok) {
219         TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
220         return false;
221       }
222       unknown_argvs[positional_count] = false;
223       processed_flags[flag.name_] = positional_count;
224       continue;
225     }
226 
227     // Parse other flags.
228     bool was_found = false;
229     for (int i = positional_count + 1; i < *argc; ++i) {
230       if (!unknown_argvs[i]) continue;
231       bool value_parsing_ok;
232       was_found = flag.Parse(argv[i], &value_parsing_ok);
233       if (!value_parsing_ok) {
234         TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
235                           << "' against argv '" << argv[i] << "'";
236         result = false;
237       }
238       if (was_found) {
239         unknown_argvs[i] = false;
240         processed_flags[flag.name_] = i;
241         break;
242       }
243     }
244 
245     // If the flag is found from the argv (i.e. the flag name appears in argv),
246     // continue to the next flag parsing.
247     if (was_found) continue;
248 
249     // The flag isn't found, do some bookkeeping work.
250     processed_flags[flag.name_] = -1;
251     if (flag.flag_type_ == Flag::kRequired) {
252       TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
253       result = false;
254       // If the required flag isn't found, we immediately stop the whole flag
255       // parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
256       // loop).
257       break;
258     }
259   }
260 
261   int dst = 1;  // Skip argv[0]
262   for (int i = 1; i < *argc; ++i) {
263     if (unknown_argvs[i]) {
264       argv[dst++] = argv[i];
265     }
266   }
267   *argc = dst;
268   return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
269 }
270 
Usage(const std::string & cmdline,const std::vector<Flag> & flag_list)271 /*static*/ std::string Flags::Usage(const std::string& cmdline,
272                                     const std::vector<Flag>& flag_list) {
273   // Stores indexes of flag_list in a sorted order.
274   std::vector<int> sorted_idx(flag_list.size());
275   std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
276   std::sort(sorted_idx.begin(), sorted_idx.end(), [&flag_list](int a, int b) {
277     return flag_list[a].GetFlagType() < flag_list[b].GetFlagType();
278   });
279   // Counts number of positional flags will be shown.
280   int positional_count = 0;
281   std::ostringstream usage_text;
282   usage_text << "usage: " << cmdline;
283   // Prints usage for positional flag.
284   for (int i = 0; i < sorted_idx.size(); ++i) {
285     const Flag& flag = flag_list[sorted_idx[i]];
286     if (flag.flag_type_ == Flag::kPositional) {
287       positional_count++;
288       usage_text << " <" << flag.name_ << ">";
289     } else {
290       usage_text << " <flags>";
291       break;
292     }
293   }
294   usage_text << "\n";
295 
296   // Finds the max number of chars of the name column in the usage message.
297   int max_name_width = 0;
298   std::vector<std::string> name_column(flag_list.size());
299   for (int i = 0; i < sorted_idx.size(); ++i) {
300     const Flag& flag = flag_list[sorted_idx[i]];
301     if (flag.flag_type_ != Flag::kPositional) {
302       name_column[i] += "--";
303       name_column[i] += flag.name_;
304       name_column[i] += "=";
305       name_column[i] += flag.default_for_display_;
306     } else {
307       name_column[i] += flag.name_;
308     }
309     if (name_column[i].size() > max_name_width) {
310       max_name_width = name_column[i].size();
311     }
312   }
313 
314   if (positional_count > 0) {
315     usage_text << "Where:\n";
316   }
317   for (int i = 0; i < sorted_idx.size(); ++i) {
318     const Flag& flag = flag_list[sorted_idx[i]];
319     if (i == positional_count) {
320       usage_text << "Flags:\n";
321     }
322     auto type_name = flag.GetTypeName();
323     usage_text << "\t";
324     usage_text << std::left << std::setw(max_name_width) << name_column[i];
325     usage_text << "\t" << type_name << "\t";
326     usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
327                                                       : "optional");
328     usage_text << "\t" << flag.usage_text_ << "\n";
329   }
330   return usage_text.str();
331 }
332 
ArgsToString(int argc,const char ** argv)333 /*static*/ std::string Flags::ArgsToString(int argc, const char** argv) {
334   std::string args;
335   for (int i = 1; i < argc; ++i) {
336     args.append(argv[i]);
337     if (i != argc - 1) args.append(" ");
338   }
339   return args;
340 }
341 
342 }  // namespace tflite
343