# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Converts a frozen graph into a TFLite FlatBuffer."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import distutils.spawn
23import enum  # pylint: disable=g-bad-import-order
24import os as _os
25import platform as _platform
26import subprocess as _subprocess
27import tempfile as _tempfile
28
29import six
30from six.moves import map
31
32from tensorflow.lite.python import lite_constants
33from tensorflow.lite.python import util
34from tensorflow.lite.python import wrap_toco
35from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
36from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
37from tensorflow.lite.toco import types_pb2 as _types_pb2
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.platform import resource_loader as _resource_loader
41from tensorflow.python.util import deprecation
42from tensorflow.python.util.tf_export import tf_export as _tf_export
43
44_quantized_inference_types = [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]
45
46
47# If the `inference_type` or the `inference_input_type` is the quantized type
48# and it is not post training quantization, the input quantization stats is
49# required.
def _requires_input_stats(toco_flags):
  return ((toco_flags.inference_type in _quantized_inference_types or
           toco_flags.inference_input_type in _quantized_inference_types) and
          not toco_flags.post_training_quantize)
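
# Illustrative sketch of the predicate above (the flags object here is a
# hypothetical, locally-built one):
#
#   flags = _toco_flags_pb2.TocoFlags()
#   flags.inference_type = _types_pb2.QUANTIZED_UINT8
#   assert _requires_input_stats(flags)      # quantized, no PTQ -> stats
#   flags.post_training_quantize = True
#   assert not _requires_input_stats(flags)  # PTQ derives its own ranges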

# Find the toco_from_protos binary using the resource loader if running from
# Bazel; otherwise we are in a pip install, where console_scripts already
# provides the toco_from_protos tool.
if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"


def _try_convert_to_unicode(output):
  if output is None:
    return u""

  if isinstance(output, bytes):
    try:
      return six.ensure_text(output)
    except UnicodeDecodeError:
      pass
  return output


@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Convert model using only TensorFlow Lite quantized int8 operations.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Convert model using only TensorFlow Lite operations with quantized int8
  # weights, int16 activations and int64 bias.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  # This quantization mode may be used in models for super-resolution,
  # audio signal processing or image de-noising. It improves accuracy
  # significantly, but only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized.
  # They are only compatible with CPU execution, and have not been optimized for
  # production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = \
    "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"

  def __str__(self):
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns a list of OpsSet options as a list of strings."""
    return [str(option) for option in list(OpsSet)]
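
# Example (illustrative): `OpsSet` values are typically assigned to a
# converter's target spec; `converter` below stands in for a hypothetical
# `tf.lite.TFLiteConverter` instance:
#
#   converter.target_spec.supported_ops = [
#       OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS
#   ]
#   OpsSet.get_options()  # ["TFLITE_BUILTINS", "SELECT_TF_OPS", ...]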


class ConverterError(Exception):
  """Raised when an error occurs during model conversion."""
  pass


def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.INT8,
                  enable_numeric_verify=False):
  """Quantize `input_data_str` with calibration results.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or
      per-tensor quantization.
    fully_quantize: Bool indicating whether to fully quantize the model. In
      addition to the model body, the input/output tensors will be quantized
      as well.
    inference_type: Data type for the activations. The default value is int8.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.

  Returns:
    Quantized model in serialized form (e.g. a TFLITE model); the inputs and
    outputs remain floating-point unless `fully_quantize` is set.
  """
  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                      disable_per_channel,
                                                      fully_quantize,
                                                      inference_type,
                                                      enable_numeric_verify)
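
# Example (illustrative sketch; `calibrated_model` is assumed to be a TFLite
# flatbuffer, as bytes, that already embeds calibration results):
#
#   quantized_model = mlir_quantize(calibrated_model,
#                                   fully_quantize=True,
#                                   inference_type=_types_pb2.INT8)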


def mlir_sparsify(input_data_str):
  """Sparsify `input_data_str` to encode sparse tensors with the proper format.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLITE model).
  """
  return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
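
# Example (illustrative; `tflite_model` is an assumed serialized TFLite model
# whose weight tensors are mostly zeros):
#
#   sparse_model = mlir_sparsify(tflite_model)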


def register_custom_opdefs(custom_opdefs_list):
  """Register the given custom opdefs to the TensorFlow global op registry.

  Args:
    custom_opdefs_list: List of strings representing the custom ops OpDefs
      that are included in the GraphDef.

  Returns:
    True if the registration is successfully completed.
  """
  return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
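
# Example (illustrative; `CustomAdd` is a hypothetical custom op, written in
# the text proto format of a TensorFlow `OpDef`):
#
#   register_custom_opdefs([
#       "name: 'CustomAdd' input_arg: { name: 'x' type: DT_FLOAT } "
#       "input_arg: { name: 'y' type: DT_FLOAT } "
#       "output_arg: { name: 'sum' type: DT_FLOAT }"
#   ])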


def toco_convert_protos(model_flags_str,
                        toco_flags_str,
                        input_data_str,
                        debug_info_str=None,
                        enable_mlir_converter=False):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing, consider using
  the friendlier `tf.compat.v1.lite.toco_convert`.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of the default
      TOCO conversion. (default False)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When conversion fails in TFLiteConverter, usually due to
      ops not being supported.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  # Historically, TOCO conversion failures would trigger a crash, so we would
  # attempt to run the converter out-of-process. The MLIR conversion pipeline
  # surfaces errors instead, and can be safely run in-process.
  if enable_mlir_converter or not _toco_from_proto_bin:
    try:
      model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                 toco_flags_str, input_data_str,
                                                 debug_info_str,
                                                 enable_mlir_converter)
      return model_str
    except Exception as e:
      raise ConverterError(str(e))

  if distutils.spawn.find_executable(_toco_from_proto_bin) is None:
    raise ConverterError("""Could not find toco_from_protos binary, make sure
your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.

For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin

Alternatively, use virtualenv.""")
  # Windows and TemporaryFile are not that useful together, since a file
  # cannot have two open readers/writers. So we create the temporaries
  # ourselves, then close and delete them explicitly.
  toco_filename, model_filename, input_filename, output_filename, \
      debug_filename = (None, None, None, None, None)
  try:
    # Build all input files
    with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
             _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      toco_filename = fp_toco.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_toco.write(toco_flags_str)
      fp_input.write(six.ensure_binary(input_data_str))
      debug_info_str = debug_info_str if debug_info_str else ""
      # fp_debug.write() requires a bytes-like object, so a str
      # `debug_info_str` (which some subtests within the "convert_test"
      # unit test pass in) would fail with:
      #
      #   TypeError: a bytes-like object is required, not 'str'
      #
      # Watch out for that scenario and encode debug_info_str where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run
    cmd = [
        _toco_from_proto_bin,
        model_filename,
        toco_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    if enable_mlir_converter:
      cmd.append("--enable_mlir_converter")
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually cleanup files.
    for filename in [
        toco_filename, input_filename, model_filename, output_filename,
        debug_filename
    ]:
      try:
        _os.unlink(filename)
      except (OSError, TypeError):
        pass
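
# Example (illustrative sketch; `graph_def` is an assumed frozen `GraphDef`,
# and `in_t`/`out_t` are tensors taken from the loaded graph):
#
#   model_flags, toco_flags, _ = build_toco_convert_protos(
#       input_tensors=[in_t], output_tensors=[out_t])
#   tflite_model = toco_convert_protos(model_flags.SerializeToString(),
#                                      toco_flags.SerializeToString(),
#                                      graph_def.SerializeToString())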


def build_toco_flags(inference_type=dtypes.float32,
                     inference_input_type=None,
                     input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                     output_format=lite_constants.TFLITE,
                     default_ranges_stats=None,
                     drop_control_dependency=True,
                     reorder_across_fake_quant=False,
                     allow_custom_ops=False,
                     custom_opdefs=None,
                     post_training_quantize=False,
                     quantize_to_float16=False,
                     dump_graphviz_dir=None,
                     dump_graphviz_video=False,
                     target_ops=None,
                     conversion_summary_dir=None,
                     select_user_tf_ops=None,
                     enable_tflite_resource_variables=False,
                     **_):
  """Build the TOCO flags object from params."""
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.inference_type = util.convert_dtype_to_tflite_type(inference_type)
  if inference_input_type:
    toco.inference_input_type = util.convert_dtype_to_tflite_type(
        inference_input_type)
  else:
    toco.inference_input_type = toco.inference_type
  toco.drop_control_dependency = drop_control_dependency
  toco.reorder_across_fake_quant = reorder_across_fake_quant
  toco.allow_custom_ops = allow_custom_ops
  if custom_opdefs:
    toco.custom_opdefs.extend(custom_opdefs)
  if select_user_tf_ops:
    toco.select_user_tf_ops.extend(select_user_tf_ops)
  toco.post_training_quantize = post_training_quantize
  toco.quantize_to_float16 = quantize_to_float16
  if default_ranges_stats:
    toco.default_ranges_min = default_ranges_stats[0]
    toco.default_ranges_max = default_ranges_stats[1]
  if dump_graphviz_dir:
    toco.dump_graphviz_dir = dump_graphviz_dir
  toco.dump_graphviz_include_video = dump_graphviz_video
  if conversion_summary_dir:
    toco.conversion_summary_dir = conversion_summary_dir
  if target_ops:
    if OpsSet.SELECT_TF_OPS in set(target_ops):
      toco.enable_select_tf_ops = True
    if set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
      toco.force_select_tf_ops = True
  toco.enable_tflite_resource_variables = enable_tflite_resource_variables
  return toco
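
# Example (illustrative): flags for a float model that may fall back to
# select TensorFlow ops:
#
#   flags = build_toco_flags(
#       inference_type=dtypes.float32,
#       target_ops=set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]))
#   assert flags.enable_select_tf_ops
#   assert not flags.force_select_tf_ops  # builtins are still preferred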


def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=dtypes.float32,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,
                              output_format=lite_constants.TFLITE,
                              quantized_input_stats=None,
                              default_ranges_stats=None,
                              drop_control_dependency=True,
                              reorder_across_fake_quant=False,
                              allow_custom_ops=False,
                              custom_opdefs=None,
                              change_concat_input_ranges=False,
                              post_training_quantize=False,
                              quantize_to_float16=False,
                              dump_graphviz_dir=None,
                              dump_graphviz_video=False,
                              target_ops=None,
                              allow_nonexistent_arrays=False,
                              debug_info=None,
                              conversion_summary_dir=None,
                              saved_model_dir=None,
                              saved_model_version=0,
                              saved_model_tags=None,
                              saved_model_exported_names=None,
                              select_user_tf_ops=None):
  """Builds protocol buffers describing a conversion of a model using TOCO.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer. If
      `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    input_format: Type of data to read.
      (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
    input_shapes: Input array shape. (default None, must be None or a list of
      the same length as `input_tensors`.)
    output_format: Output file format. (default TFLITE, must be in
      {TFLITE, GRAPHVIZ_DOT})
    quantized_input_stats: List of tuples of floats representing the mean and
      standard deviation of the training data, one tuple per input tensor in
      the same order as `input_tensors` (e.g., [(0., 1.)]). Required if
      `inference_input_type` is tf.int8 or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false, any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    custom_opdefs: List of strings representing custom ops OpDefs that are
      included in the GraphDef. Required when using custom operations with the
      MLIR-based converter. (default None)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will be
      latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers to
      float16. (default False)
    dump_graphviz_dir: Full filepath of the folder to dump GraphViz .dot files
      of the graph at various stages of processing. Preferred over
      `output_format=GRAPHVIZ_DOT` since the requested output file is still
      produced. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
    allow_nonexistent_arrays: Allow specifying array names that don't exist or
      are unused in the final graph. (default False)
    debug_info: `GraphDebugInfo` proto containing the stack traces for the
      original nodes referred to by the converted graph.
    conversion_summary_dir: A string, the path to the generated conversion logs.
    saved_model_dir: Filepath of the saved model to be converted. This value
      is non-empty only when the SavedModel import path is used; otherwise,
      the GraphDef-based conversion is performed.
    saved_model_version: SavedModel file format version of the saved model file
      to be converted. Set only when the SavedModel import path is used.
    saved_model_tags: Set of strings of SavedModel tags. Set only when the
      SavedModel import path is used.
    saved_model_exported_names: Names to be exported (default: export all).
      Set only when the SavedModel import path is used.
    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
      supported in the TensorFlow Lite runtime. These ops will be supported as
      select TensorFlow ops.

  Returns:
    model_flags, toco_flags, debug_info: three protocol buffers describing the
      conversion process and debug information.

  Raises:
    ValueError: If the input tensor type is unknown, or if mean_values or
      std_dev_values are missing when required.
    RuntimeError: If TOCO fails to convert (in which case the runtime error's
      error text will contain the TOCO error log).
481  """
482  toco = build_toco_flags(inference_type, inference_input_type, input_format,
483                          output_format, default_ranges_stats,
484                          drop_control_dependency, reorder_across_fake_quant,
485                          allow_custom_ops, custom_opdefs,
486                          post_training_quantize, quantize_to_float16,
487                          dump_graphviz_dir, dump_graphviz_video, target_ops,
488                          conversion_summary_dir, select_user_tf_ops)
489  model = _model_flags_pb2.ModelFlags()
490  model.change_concat_input_ranges = change_concat_input_ranges
491  for idx, input_tensor in enumerate(input_tensors):
492    input_array = model.input_arrays.add()
493    if saved_model_dir:
494      input_array.name = input_tensor.name
495    else:
496      input_array.name = util.get_tensor_name(input_tensor)
497    input_array.data_type = util.convert_dtype_to_tflite_type(
498        input_tensor.dtype)
499
500    if _requires_input_stats(toco) and quantized_input_stats:
501      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
502
503    if input_shapes is None:
504      shape = input_tensor.shape
505    else:
506      shape = input_shapes[idx]
507
508    if shape.rank is not None:
509      # Create shapes with -1 for unknown dimensions.
510      dims = []
511      for dim in shape:
512        if (dim is None or
513            (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
514          dims.append(-1)
515        else:
516          dims.append(int(dim))
517      input_array.shape.dims.extend(dims)
518      input_array.shape.unknown_rank = False
519    else:
520      input_array.shape.unknown_rank = True
521
522  for output_tensor in output_tensors:
523    if saved_model_dir:
524      model.output_arrays.append(output_tensor.name)
525    else:
526      model.output_arrays.append(util.get_tensor_name(output_tensor))
527
528  model.allow_nonexistent_arrays = allow_nonexistent_arrays
529
530  if saved_model_dir:
531    model.saved_model_dir = saved_model_dir
532  model.saved_model_version = saved_model_version
533  if saved_model_tags:
534    model.saved_model_tags.extend(saved_model_tags)
535  if saved_model_exported_names:
536    model.saved_model_exported_names.extend(saved_model_exported_names)
537
538  return model, toco, debug_info
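
# Example (illustrative sketch of a quantized setup; `in_t`/`out_t` are
# assumed tensors from a frozen graph, using "dummy quantization" ranges for
# arrays without specified ranges):
#
#   model, toco, _ = build_toco_convert_protos(
#       [in_t], [out_t],
#       inference_type=dtypes.uint8,
#       quantized_input_stats=[(128., 127.)],
#       default_ranges_stats=(0, 6))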


def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
                           enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  This function is used to convert GraphDefs that cannot be loaded into
  TensorFlow to TFLite. Conversion can be customized by providing arguments
  that are forwarded to `build_toco_convert_protos` (see documentation for
  details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_arrays_with_shape: List of tuples pairing an input tensor name
      (string) with its shape (list of integers), e.g.,
      [("foo", [1, 16, 16, 3])]. Use only when the graph cannot be loaded into
      TensorFlow and when `input_tensors` is None. (default None)
    output_arrays: List of output tensors to freeze the graph with. Use only
      when the graph cannot be loaded into TensorFlow and when
      `output_tensors` is None. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, _ = build_toco_convert_protos(
      input_tensors=[], output_tensors=[], *args, **kwargs)

  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if _requires_input_stats(toco_flags):
      if (("quantized_input_stats" not in kwargs) or
          (not kwargs["quantized_input_stats"])):
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
      input_array.mean_value, input_array.std_value = kwargs[
          "quantized_input_stats"][idx]
    input_array.name = name
    input_array.shape.dims.extend(list(map(int, shape)))

  for name in output_arrays:
    model_flags.output_arrays.append(name)

  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      enable_mlir_converter=enable_mlir_converter)
  return data
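
# Example (illustrative sketch; `graph_def` is an assumed `GraphDef` that
# cannot be loaded into a TensorFlow session, e.g. because of unregistered
# ops):
#
#   tflite_model = toco_convert_graph_def(
#       graph_def,
#       input_arrays_with_shape=[("input", [1, 224, 224, 3])],
#       output_arrays=["output"],
#       enable_mlir_converter=False)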


def toco_convert_impl(input_data, input_tensors, output_tensors,
                      enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, debug_info = build_toco_convert_protos(
      input_tensors, output_tensors, *args, **kwargs)
  debug_info_str = debug_info.SerializeToString() if debug_info else None
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info_str,
      enable_mlir_converter=enable_mlir_converter)
  return data


def convert_saved_model(saved_model_dir=None,
                        saved_model_version=0,
                        saved_model_tags=None,
                        saved_model_exported_names=None,
                        **kwargs):
  """Converts a SavedModel using the TF Lite converter."""
  model_flags = _model_flags_pb2.ModelFlags()
  if saved_model_dir:
    model_flags.saved_model_dir = saved_model_dir
  model_flags.saved_model_version = saved_model_version
  if saved_model_tags:
    model_flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model_flags.saved_model_exported_names.extend(saved_model_exported_names)
  toco_flags = build_toco_flags(**kwargs)
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      None,  # input_data_str, unused
      None,  # debug_info_str, unused
      enable_mlir_converter=True)
  return data
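
# Example (illustrative sketch; the directory, tag and exported name below are
# the usual serving defaults, not values required by this API):
#
#   tflite_model = convert_saved_model(
#       saved_model_dir="/tmp/saved_model",
#       saved_model_version=1,
#       saved_model_tags=["serve"],
#       saved_model_exported_names=["serving_default"])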


@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).
  This function has been deprecated. Please use `tf.lite.TFLiteConverter`
  instead.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
  return toco_convert_impl(input_data, input_tensors, output_tensors,
                           enable_mlir_converter, *args, **kwargs)
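
# Example (illustrative sketch of the deprecated entry point; `sess` is an
# assumed `tf.compat.v1.Session` whose graph defines tensors `in_t` and
# `out_t`):
#
#   tflite_model = toco_convert(sess.graph_def, [in_t], [out_t])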