1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python command line interface for running TOCO."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import os
23import sys
24
25from tensorflow.lite.python import lite
26from tensorflow.lite.python import lite_constants
27from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
28from tensorflow.python import tf2
29from tensorflow.python.platform import app
30
31
32def _parse_array(values, type_fn=str):
33  if values is not None:
34    return [type_fn(val) for val in values.split(",") if val]
35  return None
36
37
38def _parse_set(values):
39  if values is not None:
40    return set([item for item in values.split(",") if item])
41  return None
42
43
def _parse_inference_type(value, flag):
  """Converts the inference type to the value of the constant.

  Args:
    value: str representing the inference type.
    flag: str representing the flag name (used in the error message).

  Returns:
    tf.dtype.

  Raises:
    ValueError: Unsupported value.
  """
  # Only these two inference types are accepted by the CLI.
  supported_types = {
      "FLOAT": lite_constants.FLOAT,
      "QUANTIZED_UINT8": lite_constants.QUANTIZED_UINT8,
  }
  if value in supported_types:
    return supported_types[value]
  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
                   "QUANTIZED_UINT8 are supported.".format(flag))
63
64
def _get_toco_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  output_arrays = _parse_array(flags.output_arrays)

  input_shapes = None
  if flags.input_shapes:
    # Shapes are colon-separated per input array; each shape is a
    # comma-separated list of ints, mapped positionally onto input_arrays.
    shapes = [
        _parse_array(one_shape, type_fn=int)
        for one_shape in flags.input_shapes.split(":")
    ]
    input_shapes = dict(zip(input_arrays, shapes))

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays
  }

  # Dispatch on the (mutually exclusive) input-file flag that was provided.
  if flags.graph_def_file:
    converter_kwargs["graph_def_file"] = flags.graph_def_file
    return lite.TFLiteConverter.from_frozen_graph(**converter_kwargs)
  if flags.saved_model_dir:
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
    return lite.TFLiteConverter.from_saved_model(**converter_kwargs)
  if flags.keras_model_file:
    converter_kwargs["model_file"] = flags.keras_model_file
    return lite.TFLiteConverter.from_keras_model_file(**converter_kwargs)
  raise ValueError("--graph_def_file, --saved_model_dir, or "
                   "--keras_model_file must be specified.")
111
112
def _convert_model(flags):
  """Calls function to convert the TensorFlow model into a TFLite model.

  Builds a converter from the flags, applies all conversion options, runs
  the conversion, and writes the serialized result to --output_file.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Create converter.
  converter = _get_toco_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    # Flag arrives as the string "TRUE"/"FALSE" (see argparse choices).
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops
  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_ops = set()
    for option in flags.target_ops.split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_ops.add(lite.OpsSet(option))

  if flags.post_training_quantize:
    converter.post_training_quantize = flags.post_training_quantize
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
      converter.inference_type = lite_constants.FLOAT

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # Bug fix: previously assigned to the misspelled attribute
    # "dump_graphviz_vode", which silently discarded the flag.
    converter.dump_graphviz_video = flags.dump_graphviz_video

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(output_data)
195
196
197def _check_flags(flags, unparsed):
198  """Checks the parsed and unparsed flags to ensure they are valid.
199
200  Raises an error if previously support unparsed flags are found. Raises an
201  error for parsed flags that don't meet the required conditions.
202
203  Args:
204    flags: argparse.Namespace object containing TFLite flags.
205    unparsed: List of unparsed flags.
206
207  Raises:
208    ValueError: Invalid flags.
209  """
210
211  # Check unparsed flags for common mistakes based on previous TOCO.
212  def _get_message_unparsed(flag, orig_flag, new_flag):
213    if flag.startswith(orig_flag):
214      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
215    return ""
216
217  if unparsed:
218    output = ""
219    for flag in unparsed:
220      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
221      output += _get_message_unparsed(flag, "--savedmodel_directory",
222                                      "--saved_model_dir")
223      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
224      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
225      output += _get_message_unparsed(flag, "--dump_graphviz",
226                                      "--dump_graphviz_dir")
227    if output:
228      raise ValueError(output)
229
230  # Check that flags are valid.
231  if flags.graph_def_file and (not flags.input_arrays or
232                               not flags.output_arrays):
233    raise ValueError("--input_arrays and --output_arrays are required with "
234                     "--graph_def_file")
235
236  if flags.input_shapes:
237    if not flags.input_arrays:
238      raise ValueError("--input_shapes must be used with --input_arrays")
239    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
240      raise ValueError("--input_shapes and --input_arrays must have the same "
241                       "number of items")
242
243  if flags.std_dev_values or flags.mean_values:
244    if bool(flags.std_dev_values) != bool(flags.mean_values):
245      raise ValueError("--std_dev_values and --mean_values must be used "
246                       "together")
247    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
248      raise ValueError("--std_dev_values, --mean_values must have the same "
249                       "number of items")
250
251  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
252    raise ValueError("--default_ranges_min and --default_ranges_max must be "
253                     "used together")
254
255  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
256    raise ValueError("--dump_graphviz_video must be used with "
257                     "--dump_graphviz_dir")
258
259
def _get_parser():
  """Builds the command line parser for tflite_convert.

  Extracted from run_main so parser construction is isolated from the
  parse/validate/convert control flow. All flags and help strings are
  unchanged.

  Returns:
    argparse.ArgumentParser instance with all TFLite conversion flags.
  """
  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Optimizing "
                   "Converter (TOCO)."))

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  # Input file flags.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help="Target data type of real-number arrays in the output file.")
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))

  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action="store_true",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))
  return parser


def run_main(_):
  """Main in toco_convert.py.

  Parses the command line, validates the flags, and runs the conversion.
  Exits with status 1 on invalid flags.
  """
  if tf2.enabled():
    raise ValueError("tflite_convert is currently unsupported in 2.0. "
                     "Please use the Python API "
                     "tf.lite.TFLiteConverter.from_concrete_function().")

  parser = _get_parser()
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
  try:
    _check_flags(tflite_flags, unparsed)
  except ValueError as e:
    # Mirror argparse's own error style: usage line, then "<prog>: error: ...".
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)
  _convert_model(tflite_flags)
445
446
def main():
  """Entry point: runs run_main under the TensorFlow app runner."""
  # Pass only the program name; run_main re-parses sys.argv itself.
  program_name_only = sys.argv[:1]
  app.run(main=run_main, argv=program_name_only)
449
450
# Allow invoking this module directly as a script.
if __name__ == "__main__":
  main()
453