1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/model_cmdline_flags.h"
16
17 #include <string>
18 #include <vector>
19
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/lite/toco/args.h"
26 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
27 #include "tensorflow/lite/toco/toco_port.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30
31 // "batch" flag only exists internally
32 #ifdef PLATFORM_GOOGLE
33 #include "base/commandlineflags.h"
34 #endif
35
36 namespace toco {
37
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39 int* argc, char* argv[], string* msg,
40 ParsedModelFlags* parsed_model_flags_ptr) {
41 ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42 using tensorflow::Flag;
43 std::vector<tensorflow::Flag> flags = {
44 Flag("input_array", parsed_flags.input_array.bind(),
45 parsed_flags.input_array.default_value(),
46 "Deprecated: use --input_arrays instead. Name of the input array. "
47 "If not specified, will try to read "
48 "that information from the input file."),
49 Flag("input_arrays", parsed_flags.input_arrays.bind(),
50 parsed_flags.input_arrays.default_value(),
51 "Names of the input arrays, comma-separated. If not specified, "
52 "will try to read that information from the input file."),
53 Flag("output_array", parsed_flags.output_array.bind(),
54 parsed_flags.output_array.default_value(),
55 "Deprecated: use --output_arrays instead. Name of the output array, "
56 "when specifying a unique output array. "
57 "If not specified, will try to read that information from the "
58 "input file."),
59 Flag("output_arrays", parsed_flags.output_arrays.bind(),
60 parsed_flags.output_arrays.default_value(),
61 "Names of the output arrays, comma-separated. "
62 "If not specified, will try to read "
63 "that information from the input file."),
64 Flag("input_shape", parsed_flags.input_shape.bind(),
65 parsed_flags.input_shape.default_value(),
66 "Deprecated: use --input_shapes instead. Input array shape. For "
67 "many models the shape takes the form "
68 "batch size, input array height, input array width, input array "
69 "depth."),
70 Flag("input_shapes", parsed_flags.input_shapes.bind(),
71 parsed_flags.input_shapes.default_value(),
72 "Shapes corresponding to --input_arrays, colon-separated. For "
73 "many models each shape takes the form batch size, input array "
74 "height, input array width, input array depth."),
75 Flag("batch_size", parsed_flags.batch_size.bind(),
76 parsed_flags.batch_size.default_value(),
77 "Deprecated. Batch size for the model. Replaces the first dimension "
78 "of an input size array if undefined. Use only with SavedModels "
79 "when --input_shapes flag is not specified. Always use "
80 "--input_shapes flag with frozen graphs."),
81 Flag("input_data_type", parsed_flags.input_data_type.bind(),
82 parsed_flags.input_data_type.default_value(),
83 "Deprecated: use --input_data_types instead. Input array type, if "
84 "not already provided in the graph. "
85 "Typically needs to be specified when passing arbitrary arrays "
86 "to --input_arrays."),
87 Flag("input_data_types", parsed_flags.input_data_types.bind(),
88 parsed_flags.input_data_types.default_value(),
89 "Input arrays types, comma-separated, if not already provided in "
90 "the graph. "
91 "Typically needs to be specified when passing arbitrary arrays "
92 "to --input_arrays."),
93 Flag("mean_value", parsed_flags.mean_value.bind(),
94 parsed_flags.mean_value.default_value(),
95 "Deprecated: use --mean_values instead. mean_value parameter for "
96 "image models, used to compute input "
97 "activations from input pixel data."),
98 Flag("mean_values", parsed_flags.mean_values.bind(),
99 parsed_flags.mean_values.default_value(),
100 "mean_values parameter for image models, comma-separated list of "
101 "doubles, used to compute input activations from input pixel "
102 "data. Each entry in the list should match an entry in "
103 "--input_arrays."),
104 Flag("std_value", parsed_flags.std_value.bind(),
105 parsed_flags.std_value.default_value(),
106 "Deprecated: use --std_values instead. std_value parameter for "
107 "image models, used to compute input "
108 "activations from input pixel data."),
109 Flag("std_values", parsed_flags.std_values.bind(),
110 parsed_flags.std_values.default_value(),
111 "std_value parameter for image models, comma-separated list of "
112 "doubles, used to compute input activations from input pixel "
113 "data. Each entry in the list should match an entry in "
114 "--input_arrays."),
115 Flag("variable_batch", parsed_flags.variable_batch.bind(),
116 parsed_flags.variable_batch.default_value(),
117 "If true, the model accepts an arbitrary batch size. Mutually "
118 "exclusive "
119 "with the 'batch' field: at most one of these two fields can be "
120 "set."),
121 Flag("rnn_states", parsed_flags.rnn_states.bind(),
122 parsed_flags.rnn_states.default_value(), ""),
123 Flag("model_checks", parsed_flags.model_checks.bind(),
124 parsed_flags.model_checks.default_value(),
125 "A list of model checks to be applied to verify the form of the "
126 "model. Applied after the graph transformations after import."),
127 Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128 parsed_flags.dump_graphviz.default_value(),
129 "Dump graphviz during LogDump call. If string is non-empty then "
130 "it defines path to dump, otherwise will skip dumping."),
131 Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132 parsed_flags.dump_graphviz_video.default_value(),
133 "If true, will dump graphviz at each "
134 "graph transformation, which may be used to generate a video."),
135 Flag("allow_nonexistent_arrays",
136 parsed_flags.allow_nonexistent_arrays.bind(),
137 parsed_flags.allow_nonexistent_arrays.default_value(),
138 "If true, will allow passing inexistent arrays in --input_arrays "
139 "and --output_arrays. This makes little sense, is only useful to "
140 "more easily get graph visualizations."),
141 Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
142 parsed_flags.allow_nonascii_arrays.default_value(),
143 "If true, will allow passing non-ascii-printable characters in "
144 "--input_arrays and --output_arrays. By default (if false), only "
145 "ascii printable characters are allowed, i.e. character codes "
146 "ranging from 32 to 127. This is disallowed by default so as to "
147 "catch common copy-and-paste issues where invisible unicode "
148 "characters are unwittingly added to these strings."),
149 Flag(
150 "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
151 parsed_flags.arrays_extra_info_file.default_value(),
152 "Path to an optional file containing a serialized ArraysExtraInfo "
153 "proto allowing to pass extra information about arrays not specified "
154 "in the input model file, such as extra MinMax information."),
155 Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
156 parsed_flags.model_flags_file.default_value(),
157 "Path to an optional file containing a serialized ModelFlags proto. "
158 "Options specified on the command line will override the values in "
159 "the proto."),
160 Flag("change_concat_input_ranges",
161 parsed_flags.change_concat_input_ranges.bind(),
162 parsed_flags.change_concat_input_ranges.default_value(),
163 "Boolean to change the behavior of min/max ranges for inputs and"
164 " output of the concat operators."),
165 };
166 bool asked_for_help =
167 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
168 if (asked_for_help) {
169 *msg += tensorflow::Flags::Usage(argv[0], flags);
170 return false;
171 } else {
172 if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
173 }
174 auto& dump_options = *GraphVizDumpOptions::singleton();
175 dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
176 dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
177
178 return true;
179 }
180
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)181 void ReadModelFlagsFromCommandLineFlags(
182 const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
183 toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
184
185 // Load proto containing the initial model flags.
186 // Additional flags specified on the command line will overwrite the values.
187 if (parsed_model_flags.model_flags_file.specified()) {
188 string model_flags_file_contents;
189 QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
190 &model_flags_file_contents,
191 port::file::Defaults())
192 .ok())
193 << "Specified --model_flags_file="
194 << parsed_model_flags.model_flags_file.value()
195 << " was not found or could not be read";
196 QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
197 model_flags))
198 << "Specified --model_flags_file="
199 << parsed_model_flags.model_flags_file.value()
200 << " could not be parsed";
201 }
202
203 #ifdef PLATFORM_GOOGLE
204 CHECK(!((base::SpecifiedOnCommandLine("batch") &&
205 parsed_model_flags.variable_batch.specified())))
206 << "The --batch and --variable_batch flags are mutually exclusive.";
207 #endif
208 CHECK(!(parsed_model_flags.output_array.specified() &&
209 parsed_model_flags.output_arrays.specified()))
210 << "The --output_array and --vs flags are mutually exclusive.";
211
212 if (parsed_model_flags.output_array.specified()) {
213 model_flags->add_output_arrays(parsed_model_flags.output_array.value());
214 }
215
216 if (parsed_model_flags.output_arrays.specified()) {
217 std::vector<string> output_arrays =
218 absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
219 for (const string& output_array : output_arrays) {
220 model_flags->add_output_arrays(output_array);
221 }
222 }
223
224 const bool uses_single_input_flags =
225 parsed_model_flags.input_array.specified() ||
226 parsed_model_flags.mean_value.specified() ||
227 parsed_model_flags.std_value.specified() ||
228 parsed_model_flags.input_shape.specified();
229
230 const bool uses_multi_input_flags =
231 parsed_model_flags.input_arrays.specified() ||
232 parsed_model_flags.mean_values.specified() ||
233 parsed_model_flags.std_values.specified() ||
234 parsed_model_flags.input_shapes.specified();
235
236 QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
237 << "Use either the singular-form input flags (--input_array, "
238 "--input_shape, --mean_value, --std_value) or the plural form input "
239 "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
240 "but not both forms within the same command line.";
241
242 if (parsed_model_flags.input_array.specified()) {
243 QCHECK(uses_single_input_flags);
244 model_flags->add_input_arrays()->set_name(
245 parsed_model_flags.input_array.value());
246 }
247 if (parsed_model_flags.input_arrays.specified()) {
248 QCHECK(uses_multi_input_flags);
249 for (const auto& input_array :
250 absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
251 model_flags->add_input_arrays()->set_name(string(input_array));
252 }
253 }
254 if (parsed_model_flags.mean_value.specified()) {
255 QCHECK(uses_single_input_flags);
256 model_flags->mutable_input_arrays(0)->set_mean_value(
257 parsed_model_flags.mean_value.value());
258 }
259 if (parsed_model_flags.mean_values.specified()) {
260 QCHECK(uses_multi_input_flags);
261 std::vector<string> mean_values =
262 absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
263 QCHECK(mean_values.size() == model_flags->input_arrays_size());
264 for (size_t i = 0; i < mean_values.size(); ++i) {
265 char* last = nullptr;
266 model_flags->mutable_input_arrays(i)->set_mean_value(
267 strtod(mean_values[i].data(), &last));
268 CHECK(last != mean_values[i].data());
269 }
270 }
271 if (parsed_model_flags.std_value.specified()) {
272 QCHECK(uses_single_input_flags);
273 model_flags->mutable_input_arrays(0)->set_std_value(
274 parsed_model_flags.std_value.value());
275 }
276 if (parsed_model_flags.std_values.specified()) {
277 QCHECK(uses_multi_input_flags);
278 std::vector<string> std_values =
279 absl::StrSplit(parsed_model_flags.std_values.value(), ',');
280 QCHECK(std_values.size() == model_flags->input_arrays_size());
281 for (size_t i = 0; i < std_values.size(); ++i) {
282 char* last = nullptr;
283 model_flags->mutable_input_arrays(i)->set_std_value(
284 strtod(std_values[i].data(), &last));
285 CHECK(last != std_values[i].data());
286 }
287 }
288 if (parsed_model_flags.input_data_type.specified()) {
289 QCHECK(uses_single_input_flags);
290 IODataType type;
291 QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
292 model_flags->mutable_input_arrays(0)->set_data_type(type);
293 }
294 if (parsed_model_flags.input_data_types.specified()) {
295 QCHECK(uses_multi_input_flags);
296 std::vector<string> input_data_types =
297 absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
298 QCHECK(input_data_types.size() == model_flags->input_arrays_size());
299 for (size_t i = 0; i < input_data_types.size(); ++i) {
300 IODataType type;
301 QCHECK(IODataType_Parse(input_data_types[i], &type));
302 model_flags->mutable_input_arrays(i)->set_data_type(type);
303 }
304 }
305 if (parsed_model_flags.input_shape.specified()) {
306 QCHECK(uses_single_input_flags);
307 if (model_flags->input_arrays().empty()) {
308 model_flags->add_input_arrays();
309 }
310 auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
311 shape->clear_dims();
312 const IntList& list = parsed_model_flags.input_shape.value();
313 for (auto& dim : list.elements) {
314 shape->add_dims(dim);
315 }
316 }
317 if (parsed_model_flags.input_shapes.specified()) {
318 QCHECK(uses_multi_input_flags);
319 std::vector<string> input_shapes =
320 absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
321 QCHECK(input_shapes.size() == model_flags->input_arrays_size());
322 for (size_t i = 0; i < input_shapes.size(); ++i) {
323 auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
324 shape->clear_dims();
325 // Treat an empty input shape as a scalar.
326 if (input_shapes[i].empty()) {
327 continue;
328 }
329 for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
330 int size;
331 CHECK(absl::SimpleAtoi(dim_str, &size))
332 << "Failed to parse input_shape: " << input_shapes[i];
333 shape->add_dims(size);
334 }
335 }
336 }
337
338 #define READ_MODEL_FLAG(name) \
339 do { \
340 if (parsed_model_flags.name.specified()) { \
341 model_flags->set_##name(parsed_model_flags.name.value()); \
342 } \
343 } while (false)
344
345 READ_MODEL_FLAG(variable_batch);
346
347 #undef READ_MODEL_FLAG
348
349 for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
350 auto* rnn_state_proto = model_flags->add_rnn_states();
351 for (const auto& kv_pair : element) {
352 const string& key = kv_pair.first;
353 const string& value = kv_pair.second;
354 if (key == "state_array") {
355 rnn_state_proto->set_state_array(value);
356 } else if (key == "back_edge_source_array") {
357 rnn_state_proto->set_back_edge_source_array(value);
358 } else if (key == "size") {
359 int32 size = 0;
360 CHECK(absl::SimpleAtoi(value, &size));
361 CHECK_GT(size, 0);
362 rnn_state_proto->set_size(size);
363 } else {
364 LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
365 }
366 }
367 CHECK(rnn_state_proto->has_state_array() &&
368 rnn_state_proto->has_back_edge_source_array() &&
369 rnn_state_proto->has_size())
370 << "--rnn_states must include state_array, back_edge_source_array and "
371 "size.";
372 }
373
374 for (const auto& element : parsed_model_flags.model_checks.value().elements) {
375 auto* model_check_proto = model_flags->add_model_checks();
376 for (const auto& kv_pair : element) {
377 const string& key = kv_pair.first;
378 const string& value = kv_pair.second;
379 if (key == "count_type") {
380 model_check_proto->set_count_type(value);
381 } else if (key == "count_min") {
382 int32 count = 0;
383 CHECK(absl::SimpleAtoi(value, &count));
384 CHECK_GE(count, -1);
385 model_check_proto->set_count_min(count);
386 } else if (key == "count_max") {
387 int32 count = 0;
388 CHECK(absl::SimpleAtoi(value, &count));
389 CHECK_GE(count, -1);
390 model_check_proto->set_count_max(count);
391 } else {
392 LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
393 }
394 }
395 }
396
397 if (!model_flags->has_allow_nonascii_arrays()) {
398 model_flags->set_allow_nonascii_arrays(
399 parsed_model_flags.allow_nonascii_arrays.value());
400 }
401 if (!model_flags->has_allow_nonexistent_arrays()) {
402 model_flags->set_allow_nonexistent_arrays(
403 parsed_model_flags.allow_nonexistent_arrays.value());
404 }
405 if (!model_flags->has_change_concat_input_ranges()) {
406 model_flags->set_change_concat_input_ranges(
407 parsed_model_flags.change_concat_input_ranges.value());
408 }
409
410 if (parsed_model_flags.arrays_extra_info_file.specified()) {
411 string arrays_extra_info_file_contents;
412 CHECK(port::file::GetContents(
413 parsed_model_flags.arrays_extra_info_file.value(),
414 &arrays_extra_info_file_contents, port::file::Defaults())
415 .ok());
416 ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
417 model_flags->mutable_arrays_extra_info());
418 }
419 }
420
UncheckedGlobalParsedModelFlags(bool must_already_exist)421 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
422 static auto* flags = [must_already_exist]() {
423 if (must_already_exist) {
424 fprintf(stderr, __FILE__
425 ":"
426 "GlobalParsedModelFlags() used without initialization\n");
427 fflush(stderr);
428 abort();
429 }
430 return new toco::ParsedModelFlags;
431 }();
432 return flags;
433 }
434
GlobalParsedModelFlags()435 ParsedModelFlags* GlobalParsedModelFlags() {
436 return UncheckedGlobalParsedModelFlags(true);
437 }
438
ParseModelFlagsOrDie(int * argc,char * argv[])439 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
440 // TODO(aselle): in the future allow Google version to use
441 // flags, and only use this mechanism for open source
442 auto* flags = UncheckedGlobalParsedModelFlags(false);
443 string msg;
444 bool model_success =
445 toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
446 if (!model_success || !msg.empty()) {
447 // Log in non-standard way since this happens pre InitGoogle.
448 fprintf(stderr, "%s", msg.c_str());
449 fflush(stderr);
450 abort();
451 }
452 }
453
454 } // namespace toco
455