1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cerrno>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
26 
27 namespace tensorflow {
28 namespace {
29 
ParseStringFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (string)> & hook,bool * value_parsing_ok)30 bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
31                      const std::function<bool(string)>& hook,
32                      bool* value_parsing_ok) {
33   *value_parsing_ok = true;
34   if (str_util::ConsumePrefix(&arg, "--") &&
35       str_util::ConsumePrefix(&arg, flag) &&
36       str_util::ConsumePrefix(&arg, "=")) {
37     *value_parsing_ok = hook(string(arg));
38     return true;
39   }
40 
41   return false;
42 }
43 
ParseInt32Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int32)> & hook,bool * value_parsing_ok)44 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
45                     const std::function<bool(int32)>& hook,
46                     bool* value_parsing_ok) {
47   *value_parsing_ok = true;
48   if (str_util::ConsumePrefix(&arg, "--") &&
49       str_util::ConsumePrefix(&arg, flag) &&
50       str_util::ConsumePrefix(&arg, "=")) {
51     char extra;
52     int32 parsed_int32;
53     if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
54       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
55                  << ".";
56       *value_parsing_ok = false;
57     } else {
58       *value_parsing_ok = hook(parsed_int32);
59     }
60     return true;
61   }
62 
63   return false;
64 }
65 
ParseInt64Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (int64)> & hook,bool * value_parsing_ok)66 bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
67                     const std::function<bool(int64)>& hook,
68                     bool* value_parsing_ok) {
69   *value_parsing_ok = true;
70   if (str_util::ConsumePrefix(&arg, "--") &&
71       str_util::ConsumePrefix(&arg, flag) &&
72       str_util::ConsumePrefix(&arg, "=")) {
73     char extra;
74     int64_t parsed_int64;
75     if (sscanf(arg.data(), "%" SCNd64 "%c", &parsed_int64, &extra) != 1) {
76       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
77                  << ".";
78       *value_parsing_ok = false;
79     } else {
80       *value_parsing_ok = hook(parsed_int64);
81     }
82     return true;
83   }
84 
85   return false;
86 }
87 
ParseBoolFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (bool)> & hook,bool * value_parsing_ok)88 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
89                    const std::function<bool(bool)>& hook,
90                    bool* value_parsing_ok) {
91   *value_parsing_ok = true;
92   if (str_util::ConsumePrefix(&arg, "--") &&
93       str_util::ConsumePrefix(&arg, flag)) {
94     if (arg.empty()) {
95       *value_parsing_ok = hook(true);
96       return true;
97     }
98 
99     if (arg == "=true") {
100       *value_parsing_ok = hook(true);
101       return true;
102     } else if (arg == "=false") {
103       *value_parsing_ok = hook(false);
104       return true;
105     } else {
106       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
107                  << ".";
108       *value_parsing_ok = false;
109       return true;
110     }
111   }
112 
113   return false;
114 }
115 
ParseFloatFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,const std::function<bool (float)> & hook,bool * value_parsing_ok)116 bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
117                     const std::function<bool(float)>& hook,
118                     bool* value_parsing_ok) {
119   *value_parsing_ok = true;
120   if (str_util::ConsumePrefix(&arg, "--") &&
121       str_util::ConsumePrefix(&arg, flag) &&
122       str_util::ConsumePrefix(&arg, "=")) {
123     char extra;
124     float parsed_float;
125     if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
126       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
127                  << ".";
128       *value_parsing_ok = false;
129     } else {
130       *value_parsing_ok = hook(parsed_float);
131     }
132     return true;
133   }
134 
135   return false;
136 }
137 
138 }  // namespace
139 
Flag(const char * name,tensorflow::int32 * dst,const string & usage_text)140 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
141     : name_(name),
142       type_(TYPE_INT32),
143       int32_hook_([dst](int32 value) {
144         *dst = value;
145         return true;
146       }),
147       int32_default_for_display_(*dst),
148       usage_text_(usage_text) {}
149 
Flag(const char * name,tensorflow::int64 * dst,const string & usage_text)150 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
151     : name_(name),
152       type_(TYPE_INT64),
153       int64_hook_([dst](int64 value) {
154         *dst = value;
155         return true;
156       }),
157       int64_default_for_display_(*dst),
158       usage_text_(usage_text) {}
159 
Flag(const char * name,float * dst,const string & usage_text)160 Flag::Flag(const char* name, float* dst, const string& usage_text)
161     : name_(name),
162       type_(TYPE_FLOAT),
163       float_hook_([dst](float value) {
164         *dst = value;
165         return true;
166       }),
167       float_default_for_display_(*dst),
168       usage_text_(usage_text) {}
169 
Flag(const char * name,bool * dst,const string & usage_text)170 Flag::Flag(const char* name, bool* dst, const string& usage_text)
171     : name_(name),
172       type_(TYPE_BOOL),
173       bool_hook_([dst](bool value) {
174         *dst = value;
175         return true;
176       }),
177       bool_default_for_display_(*dst),
178       usage_text_(usage_text) {}
179 
Flag(const char * name,string * dst,const string & usage_text)180 Flag::Flag(const char* name, string* dst, const string& usage_text)
181     : name_(name),
182       type_(TYPE_STRING),
183       string_hook_([dst](string value) {
184         *dst = std::move(value);
185         return true;
186       }),
187       string_default_for_display_(*dst),
188       usage_text_(usage_text) {}
189 
Flag(const char * name,std::function<bool (int32)> int32_hook,int32 default_value_for_display,const string & usage_text)190 Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
191            int32 default_value_for_display, const string& usage_text)
192     : name_(name),
193       type_(TYPE_INT32),
194       int32_hook_(std::move(int32_hook)),
195       int32_default_for_display_(default_value_for_display),
196       usage_text_(usage_text) {}
197 
Flag(const char * name,std::function<bool (int64)> int64_hook,int64 default_value_for_display,const string & usage_text)198 Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
199            int64 default_value_for_display, const string& usage_text)
200     : name_(name),
201       type_(TYPE_INT64),
202       int64_hook_(std::move(int64_hook)),
203       int64_default_for_display_(default_value_for_display),
204       usage_text_(usage_text) {}
205 
Flag(const char * name,std::function<bool (float)> float_hook,float default_value_for_display,const string & usage_text)206 Flag::Flag(const char* name, std::function<bool(float)> float_hook,
207            float default_value_for_display, const string& usage_text)
208     : name_(name),
209       type_(TYPE_FLOAT),
210       float_hook_(std::move(float_hook)),
211       float_default_for_display_(default_value_for_display),
212       usage_text_(usage_text) {}
213 
Flag(const char * name,std::function<bool (bool)> bool_hook,bool default_value_for_display,const string & usage_text)214 Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
215            bool default_value_for_display, const string& usage_text)
216     : name_(name),
217       type_(TYPE_BOOL),
218       bool_hook_(std::move(bool_hook)),
219       bool_default_for_display_(default_value_for_display),
220       usage_text_(usage_text) {}
221 
Flag(const char * name,std::function<bool (string)> string_hook,string default_value_for_display,const string & usage_text)222 Flag::Flag(const char* name, std::function<bool(string)> string_hook,
223            string default_value_for_display, const string& usage_text)
224     : name_(name),
225       type_(TYPE_STRING),
226       string_hook_(std::move(string_hook)),
227       string_default_for_display_(std::move(default_value_for_display)),
228       usage_text_(usage_text) {}
229 
Parse(string arg,bool * value_parsing_ok) const230 bool Flag::Parse(string arg, bool* value_parsing_ok) const {
231   bool result = false;
232   if (type_ == TYPE_INT32) {
233     result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
234   } else if (type_ == TYPE_INT64) {
235     result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
236   } else if (type_ == TYPE_BOOL) {
237     result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
238   } else if (type_ == TYPE_STRING) {
239     result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
240   } else if (type_ == TYPE_FLOAT) {
241     result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
242   }
243   return result;
244 }
245 
Parse(int * argc,char ** argv,const std::vector<Flag> & flag_list)246 /*static*/ bool Flags::Parse(int* argc, char** argv,
247                              const std::vector<Flag>& flag_list) {
248   bool result = true;
249   std::vector<char*> unknown_flags;
250   for (int i = 1; i < *argc; ++i) {
251     if (string(argv[i]) == "--") {
252       while (i < *argc) {
253         unknown_flags.push_back(argv[i]);
254         ++i;
255       }
256       break;
257     }
258 
259     bool was_found = false;
260     for (const Flag& flag : flag_list) {
261       bool value_parsing_ok;
262       was_found = flag.Parse(argv[i], &value_parsing_ok);
263       if (!value_parsing_ok) {
264         result = false;
265       }
266       if (was_found) {
267         break;
268       }
269     }
270     if (!was_found) {
271       unknown_flags.push_back(argv[i]);
272     }
273   }
274   // Passthrough any extra flags.
275   int dst = 1;  // Skip argv[0]
276   for (char* f : unknown_flags) {
277     argv[dst++] = f;
278   }
279   argv[dst++] = nullptr;
280   *argc = unknown_flags.size() + 1;
281   return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
282 }
283 
// Builds a usage message for `cmdline` listing every flag in `flag_list`.
//
// The output starts with a "usage:" line; a "Flags:" header is added only when
// there is at least one flag. Each flag then gets one tab-separated line:
// "--name=<display default>", its type name, and its usage text.
/*static*/ string Flags::Usage(const string& cmdline,
                               const std::vector<Flag>& flag_list) {
  string out;
  if (flag_list.empty()) {
    strings::Appendf(&out, "usage: %s\n", cmdline.c_str());
  } else {
    strings::Appendf(&out, "usage: %s\nFlags:\n", cmdline.c_str());
  }

  for (const Flag& f : flag_list) {
    const char* kind = "";
    string example;
    switch (f.type_) {
      case Flag::TYPE_INT32:
        kind = "int32";
        example = strings::Printf("--%s=%d", f.name_.c_str(),
                                  f.int32_default_for_display_);
        break;
      case Flag::TYPE_INT64:
        kind = "int64";
        // Cast so the argument matches %lld regardless of how int64 is
        // defined.
        example = strings::Printf(
            "--%s=%lld", f.name_.c_str(),
            static_cast<long long>(f.int64_default_for_display_));
        break;
      case Flag::TYPE_BOOL:
        kind = "bool";
        example = strings::Printf(
            "--%s=%s", f.name_.c_str(),
            f.bool_default_for_display_ ? "true" : "false");
        break;
      case Flag::TYPE_STRING:
        kind = "string";
        example = strings::Printf("--%s=\"%s\"", f.name_.c_str(),
                                  f.string_default_for_display_.c_str());
        break;
      case Flag::TYPE_FLOAT:
        kind = "float";
        example = strings::Printf("--%s=%f", f.name_.c_str(),
                                  f.float_default_for_display_);
        break;
    }
    strings::Appendf(&out, "\t%-33s\t%s\t%s\n", example.c_str(), kind,
                     f.usage_text_.c_str());
  }
  return out;
}
323 
324 }  // namespace tensorflow
325