• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/model_cmdline_flags.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/lite/toco/args.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

// "batch" flag only exists internally
#ifdef PLATFORM_GOOGLE
#include "base/commandlineflags.h"
#endif
35 
36 namespace toco {
37 
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39     int* argc, char* argv[], string* msg,
40     ParsedModelFlags* parsed_model_flags_ptr) {
41   ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42   using tensorflow::Flag;
43   std::vector<tensorflow::Flag> flags = {
44       Flag("input_array", parsed_flags.input_array.bind(),
45            parsed_flags.input_array.default_value(),
46            "Deprecated: use --input_arrays instead. Name of the input array. "
47            "If not specified, will try to read "
48            "that information from the input file."),
49       Flag("input_arrays", parsed_flags.input_arrays.bind(),
50            parsed_flags.input_arrays.default_value(),
51            "Names of the input arrays, comma-separated. If not specified, "
52            "will try to read that information from the input file."),
53       Flag("output_array", parsed_flags.output_array.bind(),
54            parsed_flags.output_array.default_value(),
55            "Deprecated: use --output_arrays instead. Name of the output array, "
56            "when specifying a unique output array. "
57            "If not specified, will try to read that information from the "
58            "input file."),
59       Flag("output_arrays", parsed_flags.output_arrays.bind(),
60            parsed_flags.output_arrays.default_value(),
61            "Names of the output arrays, comma-separated. "
62            "If not specified, will try to read "
63            "that information from the input file."),
64       Flag("input_shape", parsed_flags.input_shape.bind(),
65            parsed_flags.input_shape.default_value(),
66            "Deprecated: use --input_shapes instead. Input array shape. For "
67            "many models the shape takes the form "
68            "batch size, input array height, input array width, input array "
69            "depth."),
70       Flag("input_shapes", parsed_flags.input_shapes.bind(),
71            parsed_flags.input_shapes.default_value(),
72            "Shapes corresponding to --input_arrays, colon-separated. For "
73            "many models each shape takes the form batch size, input array "
74            "height, input array width, input array depth."),
75       Flag("batch_size", parsed_flags.batch_size.bind(),
76            parsed_flags.batch_size.default_value(),
77            "Deprecated. Batch size for the model. Replaces the first dimension "
78            "of an input size array if undefined. Use only with SavedModels "
79            "when --input_shapes flag is not specified. Always use "
80            "--input_shapes flag with frozen graphs."),
81       Flag("input_data_type", parsed_flags.input_data_type.bind(),
82            parsed_flags.input_data_type.default_value(),
83            "Deprecated: use --input_data_types instead. Input array type, if "
84            "not already provided in the graph. "
85            "Typically needs to be specified when passing arbitrary arrays "
86            "to --input_arrays."),
87       Flag("input_data_types", parsed_flags.input_data_types.bind(),
88            parsed_flags.input_data_types.default_value(),
89            "Input arrays types, comma-separated, if not already provided in "
90            "the graph. "
91            "Typically needs to be specified when passing arbitrary arrays "
92            "to --input_arrays."),
93       Flag("mean_value", parsed_flags.mean_value.bind(),
94            parsed_flags.mean_value.default_value(),
95            "Deprecated: use --mean_values instead. mean_value parameter for "
96            "image models, used to compute input "
97            "activations from input pixel data."),
98       Flag("mean_values", parsed_flags.mean_values.bind(),
99            parsed_flags.mean_values.default_value(),
100            "mean_values parameter for image models, comma-separated list of "
101            "doubles, used to compute input activations from input pixel "
102            "data. Each entry in the list should match an entry in "
103            "--input_arrays."),
104       Flag("std_value", parsed_flags.std_value.bind(),
105            parsed_flags.std_value.default_value(),
106            "Deprecated: use --std_values instead. std_value parameter for "
107            "image models, used to compute input "
108            "activations from input pixel data."),
109       Flag("std_values", parsed_flags.std_values.bind(),
110            parsed_flags.std_values.default_value(),
111            "std_value parameter for image models, comma-separated list of "
112            "doubles, used to compute input activations from input pixel "
113            "data. Each entry in the list should match an entry in "
114            "--input_arrays."),
115       Flag("variable_batch", parsed_flags.variable_batch.bind(),
116            parsed_flags.variable_batch.default_value(),
117            "If true, the model accepts an arbitrary batch size. Mutually "
118            "exclusive "
119            "with the 'batch' field: at most one of these two fields can be "
120            "set."),
121       Flag("rnn_states", parsed_flags.rnn_states.bind(),
122            parsed_flags.rnn_states.default_value(), ""),
123       Flag("model_checks", parsed_flags.model_checks.bind(),
124            parsed_flags.model_checks.default_value(),
125            "A list of model checks to be applied to verify the form of the "
126            "model.  Applied after the graph transformations after import."),
127       Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128            parsed_flags.dump_graphviz.default_value(),
129            "Dump graphviz during LogDump call. If string is non-empty then "
130            "it defines path to dump, otherwise will skip dumping."),
131       Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132            parsed_flags.dump_graphviz_video.default_value(),
133            "If true, will dump graphviz at each "
134            "graph transformation, which may be used to generate a video."),
135       Flag("allow_nonexistent_arrays",
136            parsed_flags.allow_nonexistent_arrays.bind(),
137            parsed_flags.allow_nonexistent_arrays.default_value(),
138            "If true, will allow passing inexistent arrays in --input_arrays "
139            "and --output_arrays. This makes little sense, is only useful to "
140            "more easily get graph visualizations."),
141       Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
142            parsed_flags.allow_nonascii_arrays.default_value(),
143            "If true, will allow passing non-ascii-printable characters in "
144            "--input_arrays and --output_arrays. By default (if false), only "
145            "ascii printable characters are allowed, i.e. character codes "
146            "ranging from 32 to 127. This is disallowed by default so as to "
147            "catch common copy-and-paste issues where invisible unicode "
148            "characters are unwittingly added to these strings."),
149       Flag(
150           "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
151           parsed_flags.arrays_extra_info_file.default_value(),
152           "Path to an optional file containing a serialized ArraysExtraInfo "
153           "proto allowing to pass extra information about arrays not specified "
154           "in the input model file, such as extra MinMax information."),
155       Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
156            parsed_flags.model_flags_file.default_value(),
157            "Path to an optional file containing a serialized ModelFlags proto. "
158            "Options specified on the command line will override the values in "
159            "the proto."),
160       Flag("change_concat_input_ranges",
161            parsed_flags.change_concat_input_ranges.bind(),
162            parsed_flags.change_concat_input_ranges.default_value(),
163            "Boolean to change the behavior of min/max ranges for inputs and"
164            " output of the concat operators."),
165   };
166   bool asked_for_help =
167       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
168   if (asked_for_help) {
169     *msg += tensorflow::Flags::Usage(argv[0], flags);
170     return false;
171   } else {
172     if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
173   }
174   auto& dump_options = *GraphVizDumpOptions::singleton();
175   dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
176   dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
177 
178   return true;
179 }
180 
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)181 void ReadModelFlagsFromCommandLineFlags(
182     const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
183   toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
184 
185   // Load proto containing the initial model flags.
186   // Additional flags specified on the command line will overwrite the values.
187   if (parsed_model_flags.model_flags_file.specified()) {
188     string model_flags_file_contents;
189     QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
190                                    &model_flags_file_contents,
191                                    port::file::Defaults())
192                .ok())
193         << "Specified --model_flags_file="
194         << parsed_model_flags.model_flags_file.value()
195         << " was not found or could not be read";
196     QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
197                                              model_flags))
198         << "Specified --model_flags_file="
199         << parsed_model_flags.model_flags_file.value()
200         << " could not be parsed";
201   }
202 
203 #ifdef PLATFORM_GOOGLE
204   CHECK(!((base::SpecifiedOnCommandLine("batch") &&
205            parsed_model_flags.variable_batch.specified())))
206       << "The --batch and --variable_batch flags are mutually exclusive.";
207 #endif
208   CHECK(!(parsed_model_flags.output_array.specified() &&
209           parsed_model_flags.output_arrays.specified()))
210       << "The --output_array and --vs flags are mutually exclusive.";
211 
212   if (parsed_model_flags.output_array.specified()) {
213     model_flags->add_output_arrays(parsed_model_flags.output_array.value());
214   }
215 
216   if (parsed_model_flags.output_arrays.specified()) {
217     std::vector<string> output_arrays =
218         absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
219     for (const string& output_array : output_arrays) {
220       model_flags->add_output_arrays(output_array);
221     }
222   }
223 
224   const bool uses_single_input_flags =
225       parsed_model_flags.input_array.specified() ||
226       parsed_model_flags.mean_value.specified() ||
227       parsed_model_flags.std_value.specified() ||
228       parsed_model_flags.input_shape.specified();
229 
230   const bool uses_multi_input_flags =
231       parsed_model_flags.input_arrays.specified() ||
232       parsed_model_flags.mean_values.specified() ||
233       parsed_model_flags.std_values.specified() ||
234       parsed_model_flags.input_shapes.specified();
235 
236   QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
237       << "Use either the singular-form input flags (--input_array, "
238          "--input_shape, --mean_value, --std_value) or the plural form input "
239          "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
240          "but not both forms within the same command line.";
241 
242   if (parsed_model_flags.input_array.specified()) {
243     QCHECK(uses_single_input_flags);
244     model_flags->add_input_arrays()->set_name(
245         parsed_model_flags.input_array.value());
246   }
247   if (parsed_model_flags.input_arrays.specified()) {
248     QCHECK(uses_multi_input_flags);
249     for (const auto& input_array :
250          absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
251       model_flags->add_input_arrays()->set_name(string(input_array));
252     }
253   }
254   if (parsed_model_flags.mean_value.specified()) {
255     QCHECK(uses_single_input_flags);
256     model_flags->mutable_input_arrays(0)->set_mean_value(
257         parsed_model_flags.mean_value.value());
258   }
259   if (parsed_model_flags.mean_values.specified()) {
260     QCHECK(uses_multi_input_flags);
261     std::vector<string> mean_values =
262         absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
263     QCHECK(mean_values.size() == model_flags->input_arrays_size());
264     for (size_t i = 0; i < mean_values.size(); ++i) {
265       char* last = nullptr;
266       model_flags->mutable_input_arrays(i)->set_mean_value(
267           strtod(mean_values[i].data(), &last));
268       CHECK(last != mean_values[i].data());
269     }
270   }
271   if (parsed_model_flags.std_value.specified()) {
272     QCHECK(uses_single_input_flags);
273     model_flags->mutable_input_arrays(0)->set_std_value(
274         parsed_model_flags.std_value.value());
275   }
276   if (parsed_model_flags.std_values.specified()) {
277     QCHECK(uses_multi_input_flags);
278     std::vector<string> std_values =
279         absl::StrSplit(parsed_model_flags.std_values.value(), ',');
280     QCHECK(std_values.size() == model_flags->input_arrays_size());
281     for (size_t i = 0; i < std_values.size(); ++i) {
282       char* last = nullptr;
283       model_flags->mutable_input_arrays(i)->set_std_value(
284           strtod(std_values[i].data(), &last));
285       CHECK(last != std_values[i].data());
286     }
287   }
288   if (parsed_model_flags.input_data_type.specified()) {
289     QCHECK(uses_single_input_flags);
290     IODataType type;
291     QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
292     model_flags->mutable_input_arrays(0)->set_data_type(type);
293   }
294   if (parsed_model_flags.input_data_types.specified()) {
295     QCHECK(uses_multi_input_flags);
296     std::vector<string> input_data_types =
297         absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
298     QCHECK(input_data_types.size() == model_flags->input_arrays_size());
299     for (size_t i = 0; i < input_data_types.size(); ++i) {
300       IODataType type;
301       QCHECK(IODataType_Parse(input_data_types[i], &type));
302       model_flags->mutable_input_arrays(i)->set_data_type(type);
303     }
304   }
305   if (parsed_model_flags.input_shape.specified()) {
306     QCHECK(uses_single_input_flags);
307     if (model_flags->input_arrays().empty()) {
308       model_flags->add_input_arrays();
309     }
310     auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
311     shape->clear_dims();
312     const IntList& list = parsed_model_flags.input_shape.value();
313     for (auto& dim : list.elements) {
314       shape->add_dims(dim);
315     }
316   }
317   if (parsed_model_flags.input_shapes.specified()) {
318     QCHECK(uses_multi_input_flags);
319     std::vector<string> input_shapes =
320         absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
321     QCHECK(input_shapes.size() == model_flags->input_arrays_size());
322     for (size_t i = 0; i < input_shapes.size(); ++i) {
323       auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
324       shape->clear_dims();
325       // Treat an empty input shape as a scalar.
326       if (input_shapes[i].empty()) {
327         continue;
328       }
329       for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
330         int size;
331         CHECK(absl::SimpleAtoi(dim_str, &size))
332             << "Failed to parse input_shape: " << input_shapes[i];
333         shape->add_dims(size);
334       }
335     }
336   }
337 
338 #define READ_MODEL_FLAG(name)                                   \
339   do {                                                          \
340     if (parsed_model_flags.name.specified()) {                  \
341       model_flags->set_##name(parsed_model_flags.name.value()); \
342     }                                                           \
343   } while (false)
344 
345   READ_MODEL_FLAG(variable_batch);
346 
347 #undef READ_MODEL_FLAG
348 
349   for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
350     auto* rnn_state_proto = model_flags->add_rnn_states();
351     for (const auto& kv_pair : element) {
352       const string& key = kv_pair.first;
353       const string& value = kv_pair.second;
354       if (key == "state_array") {
355         rnn_state_proto->set_state_array(value);
356       } else if (key == "back_edge_source_array") {
357         rnn_state_proto->set_back_edge_source_array(value);
358       } else if (key == "size") {
359         int32 size = 0;
360         CHECK(absl::SimpleAtoi(value, &size));
361         CHECK_GT(size, 0);
362         rnn_state_proto->set_size(size);
363       } else {
364         LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
365       }
366     }
367     CHECK(rnn_state_proto->has_state_array() &&
368           rnn_state_proto->has_back_edge_source_array() &&
369           rnn_state_proto->has_size())
370         << "--rnn_states must include state_array, back_edge_source_array and "
371            "size.";
372   }
373 
374   for (const auto& element : parsed_model_flags.model_checks.value().elements) {
375     auto* model_check_proto = model_flags->add_model_checks();
376     for (const auto& kv_pair : element) {
377       const string& key = kv_pair.first;
378       const string& value = kv_pair.second;
379       if (key == "count_type") {
380         model_check_proto->set_count_type(value);
381       } else if (key == "count_min") {
382         int32 count = 0;
383         CHECK(absl::SimpleAtoi(value, &count));
384         CHECK_GE(count, -1);
385         model_check_proto->set_count_min(count);
386       } else if (key == "count_max") {
387         int32 count = 0;
388         CHECK(absl::SimpleAtoi(value, &count));
389         CHECK_GE(count, -1);
390         model_check_proto->set_count_max(count);
391       } else {
392         LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
393       }
394     }
395   }
396 
397   if (!model_flags->has_allow_nonascii_arrays()) {
398     model_flags->set_allow_nonascii_arrays(
399         parsed_model_flags.allow_nonascii_arrays.value());
400   }
401   if (!model_flags->has_allow_nonexistent_arrays()) {
402     model_flags->set_allow_nonexistent_arrays(
403         parsed_model_flags.allow_nonexistent_arrays.value());
404   }
405   if (!model_flags->has_change_concat_input_ranges()) {
406     model_flags->set_change_concat_input_ranges(
407         parsed_model_flags.change_concat_input_ranges.value());
408   }
409 
410   if (parsed_model_flags.arrays_extra_info_file.specified()) {
411     string arrays_extra_info_file_contents;
412     CHECK(port::file::GetContents(
413               parsed_model_flags.arrays_extra_info_file.value(),
414               &arrays_extra_info_file_contents, port::file::Defaults())
415               .ok());
416     ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
417                                       model_flags->mutable_arrays_extra_info());
418   }
419 }
420 
UncheckedGlobalParsedModelFlags(bool must_already_exist)421 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
422   static auto* flags = [must_already_exist]() {
423     if (must_already_exist) {
424       fprintf(stderr, __FILE__
425               ":"
426               "GlobalParsedModelFlags() used without initialization\n");
427       fflush(stderr);
428       abort();
429     }
430     return new toco::ParsedModelFlags;
431   }();
432   return flags;
433 }
434 
GlobalParsedModelFlags()435 ParsedModelFlags* GlobalParsedModelFlags() {
436   return UncheckedGlobalParsedModelFlags(true);
437 }
438 
ParseModelFlagsOrDie(int * argc,char * argv[])439 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
440   // TODO(aselle): in the future allow Google version to use
441   // flags, and only use this mechanism for open source
442   auto* flags = UncheckedGlobalParsedModelFlags(false);
443   string msg;
444   bool model_success =
445       toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
446   if (!model_success || !msg.empty()) {
447     // Log in non-standard way since this happens pre InitGoogle.
448     fprintf(stderr, "%s", msg.c_str());
449     fflush(stderr);
450     abort();
451   }
452 }
453 
454 }  // namespace toco
455