1# Lint as: python2, python3
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""TensorFlow Lite tooling helper functionality."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import enum
23import shutil
24import tempfile
25import warnings
26
27from absl import logging
28import six
29from six import PY2
30
31from google.protobuf import text_format as _text_format
32from google.protobuf.message import DecodeError
33from tensorflow.core.framework import graph_pb2 as _graph_pb2
34from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn  # pylint: disable=unused-import
35from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell  # pylint: disable=unused-import
36from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell  # pylint: disable=unused-import
37from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op  # pylint: disable=unused-import
38from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops  # pylint: disable=unused-import
39from tensorflow.lite.python import lite_constants as constants
40from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
41from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
42from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
43from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
44from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
45from tensorflow.lite.python.convert import OpsSet
46from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
47from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
48from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
49from tensorflow.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
50from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
51from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
52from tensorflow.lite.python.interpreter import load_delegate  # pylint: disable=unused-import
53from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
54from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
55from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
56from tensorflow.lite.python.optimize import calibrator as _calibrator
57from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
58from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
59from tensorflow.lite.python.util import freeze_graph as _freeze_graph
60from tensorflow.lite.python.util import get_debug_info as _get_debug_info
61from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
62from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
63from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
64from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
65from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
66from tensorflow.lite.python.util import model_input_signature as _model_input_signature
67from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
68from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
69from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
70from tensorflow.lite.python.util import trace_model_call as _trace_model_call
71from tensorflow.python.client import session as _session
72from tensorflow.python.eager import context
73from tensorflow.python.eager import def_function as _def_function
74from tensorflow.python.eager import function as _function
75from tensorflow.python.framework import convert_to_constants as _convert_to_constants
76from tensorflow.python.framework import dtypes as _dtypes
77from tensorflow.python.framework import ops as _ops
78from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
79from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
80from tensorflow.python.lib.io import file_io as _file_io
81from tensorflow.python.saved_model import loader_impl as _loader_impl
82from tensorflow.python.saved_model import signature_constants as _signature_constants
83from tensorflow.python.saved_model import tag_constants as _tag_constants
84from tensorflow.python.saved_model.load import load as _load
85from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
86from tensorflow.python.util import deprecation as _deprecation
87from tensorflow.python.util import keras_deps
88from tensorflow.python.util.tf_export import tf_export as _tf_export
89
90
91@_tf_export("lite.Optimize")
92class Optimize(enum.Enum):
93  """Enum defining the optimizations to apply when generating a tflite model.
94
95  DEFAULT
96      Default optimization strategy that quantizes model weights. Enhanced
97      optimizations are gained by providing a representative dataset that
98      quantizes biases and activations as well.
99      Converter will do its best to reduce size and latency, while minimizing
100      the loss in accuracy.
101
102  OPTIMIZE_FOR_SIZE
103      Deprecated. Does the same as DEFAULT.
104
105  OPTIMIZE_FOR_LATENCY
106      Deprecated. Does the same as DEFAULT.
107
108  EXPERIMENTAL_SPARSITY
109      Experimental flag, subject to change.
110
111      Enable optimization by taking advantage of the sparse model weights
112      trained with pruning.
113
114      The converter will inspect the sparsity pattern of the model weights and
115      do its best to improve size and latency.
116      The flag can be used alone to optimize float32 models with sparse weights.
117      It can also be used together with the DEFAULT optimization mode to
118      optimize quantized models with sparse weights.
119  """
120
121  # Default optimization strategy that quantizes model weights. Enhanced
122  # optimizations are gained by providing a representative dataset that
123  # quantizes biases and activations as well.
124  # Converter will do its best to reduce size and latency, while minimizing
125  # the loss in accuracy.
126  DEFAULT = "DEFAULT"
127
128  # Deprecated. Does the same as DEFAULT.
129  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
130
131  # Deprecated. Does the same as DEFAULT.
132  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
133
134  # Experimental flag, subject to change.
135  # Enable optimization by taking advantage of the sparse model weights trained
136  # with pruning.
137  #
138  # The converter will inspect the sparsity pattern of the model weights and do
139  # its best to improve size and latency.
140  # The flag can be used alone to optimize float32 models with sparse weights.
141  # It can also be used together with the DEFAULT optimization mode to optimize
142  # quantized models with sparse weights.
143  # TODO(b/161560631): Add log message when this optimization is applied.
144  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"
145
146  def __str__(self):
147    return str(self.value)
148
149
150@_tf_export("lite.RepresentativeDataset")
151class RepresentativeDataset(object):
152  """Representative dataset used to optimize the model.
153
154  This is a generator function that provides a small dataset to calibrate or
155  estimate the range, i.e, (min, max) of all floating-point arrays in the model
156  (such as model input, activation outputs of intermediate layers, and model
157  output) for quantization. Usually, this is a small subset of a few hundred
158  samples randomly chosen, in no particular order, from the training or
159  evaluation dataset.
160  """
161
162  def __init__(self, input_gen):
163    """Creates a representative dataset.
164
165    Args:
166      input_gen: A generator function that generates input samples for the
167        model and has the same order, type and shape as the inputs to the model.
168        Usually, this is a small subset of a few hundred samples randomly
169        chosen, in no particular order, from the training or evaluation dataset.
170    """
171    self.input_gen = input_gen
172
173
174@_tf_export("lite.TargetSpec")
175class TargetSpec(object):
176  """Specification of target device used to optimize the model.
177
178  Attributes:
179    supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
180      options, where each option represents a set of operators supported by the
181      target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS}))
182    supported_types: Set of `tf.dtypes.DType` data types supported on the target
183      device. If initialized, optimization might be driven by the smallest type
184      in this set. (default set())
185    experimental_select_user_tf_ops: Experimental flag, subject to change. Set
186      of user's TensorFlow operators' names that are required in the TensorFlow
187      Lite runtime. These ops will be exported as select TensorFlow ops in the
188      model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
189      an advanced feature that should only be used if the client is using TF ops
190      that may not be linked in by default with the TF ops that are provided
191      when using the SELECT_TF_OPS path. The client is responsible for linking
192      these ops into the target runtime.
193  """
194
195  def __init__(self,
196               supported_ops=None,
197               supported_types=None,
198               experimental_select_user_tf_ops=None):
199    if supported_ops is None:
200      supported_ops = {OpsSet.TFLITE_BUILTINS}
201    self.supported_ops = supported_ops
202    if supported_types is None:
203      supported_types = set()
204    self.supported_types = supported_types
205    if experimental_select_user_tf_ops is None:
206      self.experimental_select_user_tf_ops = set()
207
208
209class QuantizationMode(object):
210  """QuantizationMode determines the quantization type from user options."""
211
212  def __init__(self, optimizations, target_spec, representative_dataset,
213               graph_def):
214    self._optimizations = optimizations
215    self._target_spec = target_spec
216    self._representative_dataset = representative_dataset
217    self._graph_def = graph_def
218
219    self._validate_int8_required()
220
221  # TODO(b/162537905): Refactor the following quantization functions -
222  # re-organize and refactor for better readability.
223  def post_training_int8_no_float(self):
224    return (self._any_optimization_enabled() and
225            self._is_int8_target_required() and
226            not self._is_int16x8_target_required() and
227            not self._is_allow_float() and
228            self._representative_dataset is not None)
229
230  def post_training_int8_allow_float(self):
231    return (self._any_optimization_enabled() and
232            not self._is_int16x8_target_required() and
233            self._representative_dataset is not None and
234            self._smallest_supported_type() == _dtypes.int8)
235
236  def is_post_training_integer_quantize_8(self):
237    return (self.post_training_int8_no_float() or
238            self.post_training_int8_allow_float())
239
240  def is_post_training_integer_quantize_16x8(self):
241    return (self.post_training_int16x8_no_float() or
242            self.post_training_int16x8_allow_float())
243
244  def is_integer_quantize(self):
245    return (self.is_post_training_integer_quantize_8() or
246            self.is_post_training_integer_quantize_16x8() or
247            self.is_training_time_int8_allow_float())
248
249  def is_training_time_int8_allow_float(self):
250    return (self._any_optimization_enabled() and
251            self.contains_training_quant_op())
252
253  def post_training_int16x8_no_float(self):
254    return (self._any_optimization_enabled() and
255            not self._is_int8_target_required() and
256            self._is_int16x8_target_required() and
257            not self._is_allow_float() and
258            self._representative_dataset is not None)
259
260  def post_training_int16x8_allow_float(self):
261    return (self._any_optimization_enabled() and
262            self._is_int16x8_target_required() and
263            self._is_allow_float())
264
265  def post_training_dynamic_range_int8(self):
266    # Post-training dynamic range quantization is only enabled if post-training
267    # int8 quantization and training time quantization was not done.
268    return (self._any_optimization_enabled() and
269            self._representative_dataset is None and
270            not self.contains_training_quant_op() and
271            self._smallest_supported_type() == _dtypes.int8)
272
273  def post_training_fp16(self):
274    return (self._any_optimization_enabled() and
275            self._smallest_supported_type() == _dtypes.float16)
276
277  def fp32_execution(self):
278    """If none of the above are true."""
279    return not (self.is_integer_quantize() or
280                self.post_training_dynamic_range_int8() or
281                self.post_training_fp16())
282
283  def activations_type(self):
284    return _dtypes.int16 if self._is_int16x8_target_required() \
285      else _dtypes.int8
286
287  def converter_flags(self, inference_ty=None, inference_input_ty=None):
288    """Flags to the converter."""
289
290    if self.is_integer_quantize():
291      return {
292          "inference_type": inference_ty if inference_ty else \
293            self.activations_type(),
294          "inference_input_type": _dtypes.float32,
295          "post_training_quantize": False,  # disable dynamic range quantization
296          "quantize_to_float16": False  # disable float16 quantization
297      }
298    elif self.post_training_dynamic_range_int8():
299      return {
300          "inference_type": _dtypes.float32,
301          "inference_input_type": _dtypes.float32,
302          "post_training_quantize": True,  # enable dynamic range quantization
303          "quantize_to_float16": False  # disable float16 quantization
304      }
305    elif self.post_training_fp16():
306      return {
307          "inference_type": _dtypes.float32,
308          "inference_input_type": _dtypes.float32,
309          "post_training_quantize": True,
310          "quantize_to_float16": True  # enable float16 quantization
311      }
312    else:
313      # Note this might still trigger (uint8) quantization to be compatible with
314      # TOCO.
315      return {
316          "inference_type": inference_ty if inference_ty else _dtypes.float32,
317          "inference_input_type": inference_input_ty,
318          "post_training_quantize": False,  # enable dynamic range quantization
319          "quantize_to_float16": False  # disable float16 quantization
320      }
321
322  def quantizer_flags(self, input_ty=None, output_ty=None):
323    """Default flags to the TFMOT quantizer."""
324
325    inference_input_type = input_ty if input_ty else _dtypes.float32
326    inference_output_type = output_ty if output_ty else _dtypes.float32
327
328    if self.post_training_int8_no_float() \
329      or self.post_training_int16x8_no_float():
330      return True, {
331          "inference_input_type": inference_input_type,
332          "inference_output_type": inference_output_type,
333          "activations_type": self.activations_type(),
334          "allow_float": False
335      }
336    elif self.post_training_int8_allow_float() \
337      or self.post_training_int16x8_allow_float():
338      return True, {
339          "inference_input_type": inference_input_type,
340          "inference_output_type": inference_output_type,
341          "activations_type": self.activations_type(),
342          "allow_float": True
343      }
344    else:
345      return False, None
346
347  def flags_modify_model_io_type(self, input_ty=None, output_ty=None):
348    """Flags for modifying the input and output type of a tflite model."""
349
350    if self.is_integer_quantize():
351      return {
352          "inference_input_type": input_ty if input_ty else _dtypes.float32,
353          "inference_output_type": output_ty if output_ty else _dtypes.float32,
354      }
355    else:
356      return None
357
358  # Below are helpers for the above functions.
359
360  def _validate_int8_required(self):
361    """Int8 mode requires certain parameters to exist and be compatible."""
362    if not self._is_int8_target_required():
363      return
364
365    if self._target_spec.supported_types and (self._smallest_supported_type() !=
366                                              _dtypes.int8):
367      raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
368                       "type to be INT8.")
369
370    if self._representative_dataset:
371      if not isinstance(self._representative_dataset, RepresentativeDataset):
372        self._representative_dataset = RepresentativeDataset(
373            self._representative_dataset)
374      if self._representative_dataset.input_gen is None:
375        raise ValueError(
376            "Provide an input generator for representative_dataset")
377    else:
378      # TODO(b/150661651): Relax this check for QAT.
379      raise ValueError("representative_dataset is required when specifying "
380                       "TFLITE_BUILTINS_INT8 or INT8 supported types.")
381
382  def _is_int8_target_required(self):
383    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
384        self._target_spec.supported_ops)) or (set(
385            self._target_spec.supported_types) == set([_dtypes.int8]))
386
387  def _is_int16x8_target_required(self):
388    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
389            in set(self._target_spec.supported_ops))
390
391  def _is_allow_float(self):
392    return (OpsSet.TFLITE_BUILTINS in set(
393        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
394            self._target_spec.supported_ops))
395
396  def _any_optimization_enabled(self):
397    return bool(
398        set(self._optimizations).intersection([
399            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
400            Optimize.DEFAULT
401        ]))
402
403  def _smallest_supported_type(self):
404    if self._target_spec.supported_types:
405      return min(self._target_spec.supported_types, key=lambda x: x.size)
406    else:
407      # The default smallest supported type is INT8.
408      return _dtypes.int8
409
410  def contains_training_quant_op(self):
411    """Checks if the graph contains any training-time quantization ops."""
412    training_quant_ops = frozenset({
413        "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
414        "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsPerChannel",
415        "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"
416    })
417
418    for node_def in self._graph_def.node:
419      if node_def.op in training_quant_ops:
420        return True
421    for function in self._graph_def.library.function:
422      for node_def in function.node_def:
423        if node_def.op in training_quant_ops:
424          return True
425    return False
426
427
428class TFLiteConverterBase(object):
429  """Converter subclass to share functionality between V1 and V2 converters."""
430
431  def __init__(self):
432    self.optimizations = set()
433    self.representative_dataset = None
434    self.target_spec = TargetSpec()
435    self.allow_custom_ops = False
436    self.experimental_new_converter = True
437    self.experimental_new_quantizer = False
438    self._experimental_new_quantizer = None
439    self._experimental_calibrate_only = False
440    self._experimental_sparsify_model = False
441    self._debug_info = None  # contains the stack traces of all the original
442    # nodes in the `GraphDef` to the converter.
443    self.saved_model_dir = None
444    self._saved_model_tags = None
445    self._saved_model_version = 0
446    self._saved_model_exported_names = []
447
448  def _grappler_config(self, optimizers=None):
449    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.
450
451    Args:
452      optimizers: List of strings that represents the list of optimizers.
453
454    Returns:
455      tf.ConfigProto.
456    """
457    if not optimizers:
458      optimizers = []
459    # MLIR converter will take care of constant folding instead of grappler.
460    if not self.experimental_new_converter:
461      optimizers.append("constfold")
462
463    is_only_flex_enabled = (
464        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
465    if is_only_flex_enabled:
466      # The layout optimizer turns NHCW to NCHW. This provides performance
467      # optimizations when Flex mode is enabled. However, this is not compatible
468      # with builtin ops.
469      optimizers.append("layout")
470    return _get_grappler_config(optimizers)
471
472  def _calibrate_quantize_model(self, result, inference_input_type,
473                                inference_output_type, activations_type,
474                                allow_float):
475    """Calibrate and quantize the model."""
476    if not isinstance(self.representative_dataset, RepresentativeDataset):
477      self.representative_dataset = RepresentativeDataset(
478          self.representative_dataset)
479
480    # Add intermediate tensors to the model if needed.
481    result = _calibrator.add_intermediate_tensors(result)
482    calibrate_quantize = _calibrator.Calibrator(result)
483    if self._experimental_calibrate_only or self.experimental_new_quantizer:
484      calibrated = calibrate_quantize.calibrate(
485          self.representative_dataset.input_gen)
486
487    if self._experimental_calibrate_only:
488      return calibrated
489    elif self.experimental_new_quantizer and (
490        activations_type != _dtypes.int16):
491      # TODO(b/175659372): remove the activations_type restriction and enable
492      # it for all the activation types.
493      return _mlir_quantize(calibrated)
494    else:
495      return calibrate_quantize.calibrate_and_quantize(
496          self.representative_dataset.input_gen, inference_input_type,
497          inference_output_type, allow_float, activations_type)
498
499  def _is_unknown_shapes_allowed(self):
500    # Unknown dimensions are only allowed with the new converter.
501    return self.experimental_new_converter
502
503  def _get_base_converter_args(self):
504    """Returns the base converter args.
505
506    Returns:
507      {key str: val}
508    """
509    args = {
510        "input_format": constants.TENSORFLOW_GRAPHDEF,
511        "allow_custom_ops": self.allow_custom_ops,
512        "debug_info": self._debug_info,
513        "target_ops": self.target_spec.supported_ops,
514        "enable_mlir_converter": self.experimental_new_converter,
515        "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops,
516    }
517
518    if self.saved_model_dir:
519      args.update({
520          "saved_model_dir": self.saved_model_dir,
521          "saved_model_version": self._saved_model_version,
522          "saved_model_tags": self._saved_model_tags,
523          "saved_model_exported_names": self._saved_model_exported_names,
524      })
525
526    return args
527
528  def _contains_function_with_implements_attr(self, saved_model_proto):
529    meta_graph = saved_model_proto.meta_graphs[0]
530    for function in meta_graph.graph_def.library.function:
531      if function.attr.get("_implements", None) or function.attr.get(
532          "api_implements", None):
533        return True
534    return False
535
536  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
537    """Parses SavedModel arguments from the given Keras/RNN SavedModel.
538
539    Args:
540      always_enable_saved_model_import: Bool. When the value is true, it enables
541        MLIR saved model import path regardless of checking the conditions.
542    """
543    if not self.experimental_new_converter:
544      self.saved_model_dir = None
545      return
546    if self.saved_model_dir:
547      try:
548        saved_model_proto, _ = (
549            _parse_saved_model_with_debug_info(self.saved_model_dir))
550      except OSError:
551        # If it fails to read the given saved model, it will fall back to the
552        # frozen graph def path.
553        self.saved_model_dir = None
554        return
555      if (not always_enable_saved_model_import and
556          not self._contains_function_with_implements_attr(saved_model_proto)):
557        self.saved_model_dir = None
558        return
559
560      if not self._saved_model_exported_names:
561        self._saved_model_exported_names = []
562      self._saved_model_version = saved_model_proto.saved_model_schema_version
563      if self._saved_model_version == 0:
564        self.saved_model_dir = None
565        logging.warning("SavedModel schema version is zero.")
566        return
567      if self._saved_model_version not in [1, 2]:
568        raise ValueError("SavedModel file format({0}) is not supported".format(
569            self._saved_model_version))
570
571  def _sparsify_model(self):
572    return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations
573
574  def _validate_experimental_new_quantizer_flag(self):
575    if self._experimental_new_quantizer is not None:
576      raise ValueError("Please use 'experimental_new_quantizer' instead.")
577
578
579class TFLiteConverterBaseV2(TFLiteConverterBase):
580  """Converter subclass to share functionality between V2 converters."""
581
582  def __init__(self):
583    """Constructor for TFLiteConverter."""
584    super(TFLiteConverterBaseV2, self).__init__()
585    self.inference_input_type = _dtypes.float32
586    self.inference_output_type = _dtypes.float32
587
588  def _validate_inference_input_output_types(self, quant_mode):
589    """Validate inference_input_type and inference_output_type flags."""
590    default_types = [_dtypes.float32]
591    # We support integer input/output for integer quantized models only.
592    if quant_mode.is_integer_quantize():
593      if quant_mode.is_post_training_integer_quantize_16x8():
594        all_types = default_types + [_dtypes.int16]
595      else:
596        all_types = default_types + [_dtypes.int8, _dtypes.uint8]
597      if self.inference_input_type not in all_types or \
598          self.inference_output_type not in all_types:
599        all_types_names = ["tf." + t.name for t in all_types]
600        raise ValueError("The inference_input_type and inference_output_type "
601                         "must be in {}.".format(all_types_names))
602    elif self.inference_input_type not in default_types or \
603        self.inference_output_type not in default_types:
604      raise ValueError("The inference_input_type and inference_output_type "
605                       "must be tf.float32.")
606
607  def convert(self, graph_def, input_tensors, output_tensors):
608    """Converts a TensorFlow GraphDef based on instance variables.
609
610    Args:
611      graph_def: Frozen TensorFlow GraphDef.
612      input_tensors: List of input tensors. Type and shape are computed using
613        `foo.shape` and `foo.dtype`.
614      output_tensors: List of output tensors (only .name is used from this).
615
616    Returns:
617      The converted data in serialized format.
618
619    Raises:
620      ValueError:
621        No concrete functions is specified.
622        Multiple concrete functions are specified.
623        Input shape is not specified.
624        Invalid quantization parameters.
625    """
626    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
627                                  self.representative_dataset, graph_def)
628
629    self._validate_inference_input_output_types(quant_mode)
630    self._validate_experimental_new_quantizer_flag()
631
632    if not self._is_unknown_shapes_allowed():
633      # Checks dimensions in input tensor.
634      for tensor in input_tensors:
635        # Note that shape_list might be empty for scalar shapes.
636        shape_list = tensor.shape.as_list()
637        if None in shape_list[1:]:
638          raise ValueError(
639              "None is only supported in the 1st dimension. Tensor '{0}' has "
640              "invalid shape '{1}'.".format(
641                  _get_tensor_name(tensor), shape_list))
642        elif shape_list and shape_list[0] is None:
643          # Set the batch size to 1 if undefined.
644          shape = tensor.shape.as_list()
645          shape[0] = 1
646          tensor.set_shape(shape)
647
648    if self._trackable_obj is None:
649      self._debug_info = _get_debug_info(
650          _build_debug_info_func(self._funcs[0].graph), graph_def)
651    else:
652      self._debug_info = _get_debug_info(
653          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
654          graph_def)
655
656    converter_kwargs = self._get_base_converter_args()
657    converter_kwargs.update(quant_mode.converter_flags())
658    if not self.experimental_new_converter:
659      logging.warning(
660          "Please consider switching to the new converter by setting "
661          "experimental_new_converter=True. "
662          "The old converter (TOCO) is deprecated.")
663    else:
664      logging.info("Using new converter: If you encounter a problem "
665                   "please file a bug. You can opt-out "
666                   "by setting experimental_new_converter=False")
667
668    # Converts model.
669    result = _toco_convert_impl(
670        input_data=graph_def,
671        input_tensors=input_tensors,
672        output_tensors=output_tensors,
673        **converter_kwargs)
674
675    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
676    if calibrate_and_quantize:
677      result = self._calibrate_quantize_model(result, **flags)
678
679    flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
680        self.inference_input_type, self.inference_output_type)
681    if flags_modify_model_io_type:
682      result = _modify_model_io_type(result, **flags_modify_model_io_type)
683
684    if self._sparsify_model():
685      result = _mlir_sparsify(result)
686
687    return result
688
689
690class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
691  """Converts the given SavedModel into TensorFlow Lite model.
692
693  Attributes:
694      saved_model_dir: Directory of the SavedModel.
695  """
696
697  def __init__(self,
698               saved_model_dir,
699               saved_model_tags=None,
700               saved_model_exported_names=None,
701               trackable_obj=None):
702    """Constructor for TFLiteConverter.
703
704    Args:
705      saved_model_dir: Directory of the SavedModel.
706      saved_model_tags: Set of tags identifying the MetaGraphDef within the
707        SavedModel to analyze. All tags in the tag set must be present. (default
708        {tf.saved_model.SERVING}).
709      saved_model_exported_names: Names to be exported when the saved model
710        import path is on.
711      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
712        reference to this object needs to be maintained so that Variables do not
713        get garbage collected since functions have a weak reference to
714        Variables. This is only required when the tf.AutoTrackable object is not
715        maintained by the user (e.g. `from_saved_model`).
716    """
717    super(TFLiteSavedModelConverterV2, self).__init__()
718    self.saved_model_dir = saved_model_dir
719    self._saved_model_tags = saved_model_tags
720    self._saved_model_exported_names = saved_model_exported_names
721    self._trackable_obj = trackable_obj
722    self._parse_saved_model_args(always_enable_saved_model_import=True)
723    self._enable_tflite_resource_variables = False
724
725  def convert(self):
726    """Converts a TensorFlow GraphDef based on instance variables.
727
728    Returns:
729      The converted data in serialized format.
730
731    Raises:
732      ValueError:
733        No concrete functions is specified.
734        Multiple concrete functions are specified.
735        Input shape is not specified.
736        Invalid quantization parameters.
737    """
738    graph = _ops.Graph()
739    saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
740    saved_model.load_graph(graph, tags=self._saved_model_tags)
741    meta_graph = saved_model.get_meta_graph_def_from_tags(
742        self._saved_model_tags)
743    # If we can't use saved model importer, then fallback
744    # to frozen graph conversion path.
745    if self.saved_model_dir is None or not self.experimental_new_converter:
746      signature_def = meta_graph.signature_def[
747          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
748      input_tensors = [
749          graph.get_tensor_by_name(signature_def.inputs[key].name)
750          for key in signature_def.inputs
751      ]
752      output_tensors = [
753          graph.get_tensor_by_name(signature_def.outputs[key].name)
754          for key in signature_def.outputs
755      ]
756      result = _freeze_saved_model(
757          self.saved_model_dir, None, None, None, self._saved_model_tags,
758          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
759      graph_def = result[0]
760      # We make sure to clear the saved_model_dir as there is some
761      # legacy code down in the caller that checks this.
762      # TODO(b/162537905): Clean these indirect dependencies.
763      self.saved_model_dir = None
764      return super(TFLiteSavedModelConverterV2,
765                   self).convert(graph_def, input_tensors, output_tensors)
766
767    if self._trackable_obj is None:
768      self._debug_info = _get_debug_info(
769          _build_debug_info_func(self._funcs[0].graph), meta_graph.graph_def)
770    else:
771      self._debug_info = _get_debug_info(
772          _convert_debug_info_func(self._trackable_obj.graph_debug_info),
773          meta_graph.graph_def)
774
775    # Get quantization options and do some sanity checks.
776    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
777                                  self.representative_dataset,
778                                  meta_graph.graph_def)
779    self._validate_inference_input_output_types(quant_mode)
780
781    converter_kwargs = self._get_base_converter_args()
782    converter_kwargs.update(quant_mode.converter_flags())
783    converter_kwargs.update({
784        "enable_tflite_resource_variables":
785            self._enable_tflite_resource_variables
786    })
787
788    result = _convert_saved_model(**converter_kwargs)
789    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
790    if calibrate_and_quantize:
791      result = self._calibrate_quantize_model(result, **flags)
792
793    flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
794        self.inference_input_type, self.inference_output_type)
795    if flags_modify_model_io_type:
796      result = _modify_model_io_type(result, **flags_modify_model_io_type)
797
798    if self._sparsify_model():
799      result = _mlir_sparsify(result)
800
801    return result
802
803
804class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
805  """Converts the given Keras model into TensorFlow Lite model."""
806
807  def __init__(self, keras_model, trackable_obj=None):
808    """Constructor for TFLiteConverter.
809
810    Args:
811      keras_model: tf.Keras.Model.
812      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
813        reference to this object needs to be maintained so that Variables do not
814        get garbage collected since functions have a weak reference to
815        Variables. This is only required when the tf.AutoTrackable object is not
816        maintained by the user (e.g. `from_saved_model`).
817    """
818    super(TFLiteKerasModelConverterV2, self).__init__()
819    self._keras_model = keras_model
820    self._trackable_obj = trackable_obj
821
822  def _convert_as_saved_model(self):
823    """Converts a Keras model as a saved model.
824
825    Returns:
826      The converted data in serialized format.
827    """
828    temp_dir = tempfile.mkdtemp()
829    try:
830      try:
831        self._keras_model.save(temp_dir, save_format="tf")
832      except Exception:  # pylint: disable=broad-except
833        # When storing the given keras model to a saved model is failed, let's
834        # use original keras model conversion pipeline.
835        return None
836      self.saved_model_dir = temp_dir
837      self._saved_model_tags = set([_tag_constants.SERVING])
838      self._saved_model_exported_names = [
839          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
840      ]
841      self._parse_saved_model_args()
842      if self.saved_model_dir:
843        graph = _ops.Graph()
844        saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
845        saved_model.load_graph(graph, tags=self._saved_model_tags)
846        meta_graph = saved_model.get_meta_graph_def_from_tags(
847            self._saved_model_tags)
848        signature_def = meta_graph.signature_def[
849            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
850        input_tensors = [
851            graph.get_tensor_by_name(signature_def.inputs[key].name)
852            for key in signature_def.inputs
853        ]
854        output_tensors = [
855            graph.get_tensor_by_name(signature_def.outputs[key].name)
856            for key in signature_def.outputs
857        ]
858        self._trackable_obj = _load(self.saved_model_dir,
859                                    self._saved_model_tags)
860        return super(TFLiteKerasModelConverterV2,
861                     self).convert(meta_graph.graph_def, input_tensors,
862                                   output_tensors)
863    finally:
864      shutil.rmtree(temp_dir, True)
865
866  def convert(self):
867    """Converts a keras model based on instance variables.
868
869    Returns:
870      The converted data in serialized format.
871
872    Raises:
873      ValueError:
874        Multiple concrete functions are specified.
875        Input shape is not specified.
876        Invalid quantization parameters.
877    """
878    saved_model_convert_result = self._convert_as_saved_model()
879    if saved_model_convert_result:
880      return saved_model_convert_result
881
882    input_signature = None
883    # If the model's call is not a `tf.function`, then we need to first get its
884    # input signature from `model_input_signature` method. We can't directly
885    # call `trace_model_call` because otherwise the batch dimension is set
886    # to None.
887    # Once we have better support for dynamic shapes, we can remove this.
888    if not isinstance(self._keras_model.call, _def_function.Function):
889      # Pass `keep_original_batch_size=True` will ensure that we get an input
890      # signature including the batch dimension specified by the user.
891      # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
892      input_signature = _model_input_signature(
893          self._keras_model, keep_original_batch_size=True)
894
895    # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
896    func = _trace_model_call(self._keras_model, input_signature)
897    concrete_func = func.get_concrete_function()
898    self._funcs = [concrete_func]
899
900    frozen_func, graph_def = (
901        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
902            self._funcs[0], lower_control_flow=False))
903
904    input_tensors = [
905        tensor for tensor in frozen_func.inputs
906        if tensor.dtype != _dtypes.resource
907    ]
908    output_tensors = frozen_func.outputs
909
910    # Run a Grappler pass.
911    grappler_config = self._grappler_config()
912    # Skip running grappler when there are no optimizers to run. If not,
913    # grappler will run with the default optimizer set and it will lead to
914    # causing an unexpected behavior.
915    if grappler_config.graph_options.rewrite_options.optimizers:
916      graph_def = _run_graph_optimizations(
917          graph_def,
918          input_tensors,
919          output_tensors,
920          config=grappler_config,
921          graph=frozen_func.graph)
922
923    return super(TFLiteKerasModelConverterV2,
924                 self).convert(graph_def, input_tensors, output_tensors)
925
926
927class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
928  """Converts the given frozen graph into TensorFlow Lite model."""
929
930  def __init__(self, funcs, trackable_obj=None):
931    """Constructor for TFLiteConverter.
932
933    Args:
934      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
935        duplicate elements.
936      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
937        reference to this object needs to be maintained so that Variables do not
938        get garbage collected since functions have a weak reference to
939        Variables. This is only required when the tf.AutoTrackable object is not
940        maintained by the user (e.g. `from_saved_model`).
941    """
942    super(TFLiteFrozenGraphConverterV2, self).__init__()
943    self._funcs = funcs
944    self._trackable_obj = trackable_obj
945
946  def convert(self):
947    """Converts a TensorFlow GraphDef based on instance variables.
948
949    Returns:
950      The converted data in serialized format.
951
952    Raises:
953      ValueError:
954        No concrete functions is specified.
955        Multiple concrete functions are specified.
956        Input shape is not specified.
957        Invalid quantization parameters.
958    """
959    # TODO(b/130297984): Add support for converting multiple function.
960
961    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
962      raise ValueError("No ConcreteFunction is specified.")
963
964    if len(self._funcs) > 1:
965      raise ValueError("This converter can only convert a single "
966                       "ConcreteFunction. Converting multiple functions is "
967                       "under development.")
968
969    frozen_func, graph_def = (
970        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
971            self._funcs[0], lower_control_flow=False))
972
973    input_tensors = [
974        tensor for tensor in frozen_func.inputs
975        if tensor.dtype != _dtypes.resource
976    ]
977    output_tensors = frozen_func.outputs
978
979    # Run a Grappler pass.
980    grappler_config = self._grappler_config()
981    # Skip running grappler when there are no optimizers to run. If not,
982    # grappler will run with the default optimizer set and it will lead to
983    # causing an unexpected behavior.
984    if grappler_config.graph_options.rewrite_options.optimizers:
985      graph_def = _run_graph_optimizations(
986          graph_def,
987          input_tensors,
988          output_tensors,
989          config=grappler_config,
990          graph=frozen_func.graph)
991
992    return super(TFLiteFrozenGraphConverterV2,
993                 self).convert(graph_def, input_tensors, output_tensors)
994
995
996@_tf_export("lite.TFLiteConverter", v1=[])
997class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
998  """Converts a TensorFlow model into TensorFlow Lite model.
999
1000  Attributes:
1001    optimizations: Experimental flag, subject to change. Set of optimizations
1002      to apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
1003      set of values of type `tf.lite.Optimize`)
1004    representative_dataset: A generator function used for integer quantization
1005      where each generated sample has the same order, type and shape as the
1006      inputs to the model. Usually, this is a small subset of a few hundred
1007      samples randomly chosen, in no particular order, from the training or
1008      evaluation dataset. This is an optional attribute, but required for full
1009      integer quantization, i.e, if `tf.int8` is the only supported type in
1010      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
1011      (default None)
1012    target_spec: Experimental flag, subject to change. Specifications of target
1013      device, including supported ops set, supported types and a set of user's
1014      defined TensorFlow operators required in the TensorFlow Lite runtime.
1015      Refer to `tf.lite.TargetSpec`.
1016    inference_input_type: Data type of the input layer. Note that integer types
1017      (tf.int8 and tf.uint8) are currently only supported for post training
1018      integer quantization and quantization aware training. (default tf.float32,
1019      must be in {tf.float32, tf.int8, tf.uint8})
1020    inference_output_type: Data type of the output layer. Note that integer
1021      types (tf.int8 and tf.uint8) are currently only supported for post
1022      training integer quantization and quantization aware training. (default
1023      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
1024    allow_custom_ops: Boolean indicating whether to allow custom operations.
1025      When False, any unknown operation is an error. When True, custom ops are
1026      created for any op that is unknown. The developer needs to provide these
1027      to the TensorFlow Lite runtime with a custom resolver. (default False)
1028    experimental_new_converter: Experimental flag, subject to change. Enables
1029      MLIR-based conversion instead of TOCO conversion. (default True)
1030    experimental_new_quantizer: Experimental flag, subject to change. Enables
1031      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
1032      (default False)
1033
1034  Example usage:
1035
1036    ```python
1037    # Converting a SavedModel to a TensorFlow Lite model.
1038    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
1039    tflite_model = converter.convert()
1040
1041    # Converting a tf.Keras model to a TensorFlow Lite model.
1042    converter = tf.lite.TFLiteConverter.from_keras_model(model)
1043    tflite_model = converter.convert()
1044
1045    # Converting ConcreteFunctions to a TensorFlow Lite model.
1046    converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
1047    tflite_model = converter.convert()
1048    ```
1049  """
1050
1051  # pylint: disable=useless-super-delegation
1052  def __init__(self, funcs, trackable_obj=None):
1053    """Constructor for TFLiteConverter.
1054
1055    Args:
1056      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1057        duplicate elements.
1058      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1059        reference to this object needs to be maintained so that Variables do not
1060        get garbage collected since functions have a weak reference to
1061        Variables. This is only required when the tf.AutoTrackable object is not
1062        maintained by the user (e.g. `from_saved_model`).
1063    """
1064    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
1065
1066  @classmethod
1067  def from_concrete_functions(cls, funcs):
1068    """Creates a TFLiteConverter object from ConcreteFunctions.
1069
1070    Args:
1071      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1072        duplicate elements. Currently converter can only convert a single
1073        ConcreteFunction. Converting multiple functions is under development.
1074
1075    Returns:
1076      TFLiteConverter object.
1077
1078    Raises:
1079      Invalid input type.
1080    """
1081    for func in funcs:
1082      if not isinstance(func, _function.ConcreteFunction):
1083        message = "This function takes in a list of ConcreteFunction."
1084        if isinstance(func, _def_function.Function):
1085          message += (" To get the ConcreteFunction from a Function,"
1086                      " call get_concrete_function.")
1087        raise ValueError(message)
1088    return cls(funcs)
1089
1090  @classmethod
1091  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
1092    """Creates a TFLiteConverter object from a SavedModel directory.
1093
1094    Args:
1095      saved_model_dir: SavedModel directory to convert.
1096      signature_keys: List of keys identifying SignatureDef containing inputs
1097        and outputs. Elements should not be duplicated. By default the
1098        `signatures` attribute of the MetaGraphdef is used. (default
1099        saved_model.signatures)
1100      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
1101        analyze. All tags in the tag set must be present. (default
1102        {tf.saved_model.SERVING} or {'serve'})
1103
1104    Returns:
1105      TFLiteConverter object.
1106
1107    Raises:
1108      Invalid signature keys.
1109    """
1110    # When run without eager enabled, this will return the legacy
1111    # TFLiteConverter.
1112    if not context.executing_eagerly():
1113      signature_key = None
1114      if signature_keys:
1115        if len(signature_keys) != 1:
1116          raise ValueError("Only support a single signature key.")
1117        else:
1118          signature_key = signature_keys[0]
1119      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
1120                      "because eager is disabled. Consider enabling eager.")
1121      return TFLiteConverter.from_saved_model(
1122          saved_model_dir, signature_key=signature_key, tag_set=tags)
1123
1124    # Ensures any graphs created in Eager mode are able to run. This is required
1125    # in order to create a tf.estimator.Exporter that exports a TFLite model.
1126    if tags is None:
1127      tags = set([_tag_constants.SERVING])
1128
1129    with context.eager_mode():
1130      saved_model = _load(saved_model_dir, tags)
1131    if not signature_keys:
1132      signature_keys = saved_model.signatures
1133
1134    if len(signature_keys) != 1:
1135      raise ValueError("Only support a single signature key.")
1136
1137    funcs = []
1138    for key in signature_keys:
1139      if key not in saved_model.signatures:
1140        raise ValueError("Invalid signature key '{}' found. Valid keys are "
1141                         "'{}'.".format(key, ",".join(saved_model.signatures)))
1142      funcs.append(saved_model.signatures[key])
1143
1144    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
1145                                                        signature_keys,
1146                                                        saved_model)
1147    if saved_model_converter.saved_model_dir:
1148      return saved_model_converter
1149
1150    return cls(funcs, saved_model)
1151
1152  @classmethod
1153  def from_keras_model(cls, model):
1154    """Creates a TFLiteConverter object from a Keras model.
1155
1156    Args:
1157      model: tf.Keras.Model
1158
1159    Returns:
1160      TFLiteConverter object.
1161    """
1162    return TFLiteKerasModelConverterV2(model)
1163
1164  # pylint: disable=useless-super-delegation
1165  def convert(self):
1166    """Converts a TensorFlow GraphDef based on instance variables.
1167
1168    Returns:
1169      The converted data in serialized format.
1170
1171    Raises:
1172      ValueError:
1173        No concrete functions is specified.
1174        Multiple concrete functions are specified.
1175        Input shape is not specified.
1176        Invalid quantization parameters.
1177    """
1178    return super(TFLiteConverterV2, self).convert()
1179
1180
1181class TFLiteConverterBaseV1(TFLiteConverterBase):
1182  """Converter subclass to share functionality between V1 converters."""
1183
1184  def __init__(self, experimental_debug_info_func):
1185    """Constructor for TFLiteConverter.
1186
1187    Args:
1188      experimental_debug_info_func: An experimental function to retrieve the
1189        graph debug info for a set of nodes from the `graph_def`.
1190    """
1191    super(TFLiteConverterBaseV1, self).__init__()
1192    self.inference_type = _dtypes.float32
1193    self.inference_input_type = None
1194    self.inference_output_type = None
1195    self.output_format = constants.TFLITE
1196    self.quantized_input_stats = {}
1197    self.default_ranges_stats = None
1198    self.drop_control_dependency = True
1199    self.reorder_across_fake_quant = False
1200    self.change_concat_input_ranges = False
1201    self.dump_graphviz_dir = None
1202    self.dump_graphviz_video = False
1203    self.conversion_summary_dir = None
1204    self._debug_info_func = experimental_debug_info_func
1205    self._custom_opdefs = None
1206
1207  def __setattr__(self, name, value):
1208    if name == "post_training_quantize":
1209      warnings.warn("Property %s is deprecated, "
1210                    "please use optimizations=[Optimize.DEFAULT]"
1211                    " instead." % name)
1212      if value:
1213        self.optimizations = [Optimize.DEFAULT]
1214      else:
1215        self.optimizations = []
1216      return
1217    if name == "target_ops":
1218      warnings.warn("Property %s is deprecated, please use "
1219                    "target_spec.supported_ops instead." % name)
1220      self.target_spec.supported_ops = value
1221      return
1222    object.__setattr__(self, name, value)
1223
1224  def __getattribute__(self, name):
1225    if name == "post_training_quantize":
1226      warnings.warn("Property %s is deprecated, "
1227                    "please use optimizations=[Optimize.DEFAULT]"
1228                    " instead." % name)
1229      return Optimize.DEFAULT in set(self.optimizations)
1230    if name == "target_ops":
1231      warnings.warn("Property %s is deprecated, please use "
1232                    "target_spec.supported_ops instead." % name)
1233      return self.target_spec.supported_ops
1234    return object.__getattribute__(self, name)
1235
1236  def _validate_quantized_input_stats(self, converter_kwargs, calibrate):
1237    """Ensure the `quantized_input_stats` flag is provided if required."""
1238
1239    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})
1240
1241    requires_quantized_input_stats = (
1242        (converter_kwargs["inference_type"] in quantized_types or
1243         converter_kwargs["inference_input_type"] in quantized_types) and
1244        not calibrate)
1245
1246    if (requires_quantized_input_stats and
1247        not converter_kwargs["quantized_input_stats"]):
1248      raise ValueError(
1249          "The `quantized_input_stats` flag must be defined when either "
1250          "`inference_type` flag or `inference_input_type` flag is set to "
1251          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
1252          "`inference_input_type={}`.".format(
1253              _get_tf_type_name(converter_kwargs["inference_type"]),
1254              _get_tf_type_name(converter_kwargs["inference_input_type"])))
1255
1256  def convert(self):
1257    """Converts a TensorFlow GraphDef based on instance variables.
1258
1259    Returns:
1260      The converted data in serialized format. Either a TFLite Flatbuffer or a
1261      Graphviz graph depending on value in `output_format`.
1262
1263    Raises:
1264      ValueError:
1265        Input shape is not specified.
1266        None value for dimension in input_tensor.
1267    """
1268    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
1269                                  self.representative_dataset, self._graph_def)
1270
1271    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
1272      # Checks dimensions in input tensor.
1273      for tensor in self._input_tensors:
1274        shape = tensor.shape
1275        if not shape:
1276          raise ValueError("Provide an input shape for input array "
1277                           "'{0}'.".format(_get_tensor_name(tensor)))
1278        # Note that shape_list might be empty for scalar shapes.
1279        shape_list = shape.as_list()
1280        if None in shape_list[1:]:
1281          raise ValueError(
1282              "None is only supported in the 1st dimension. Tensor '{0}' has "
1283              "invalid shape '{1}'.".format(
1284                  _get_tensor_name(tensor), shape_list))
1285        elif shape_list and shape_list[0] is None:
1286          self._set_batch_size(batch_size=1)
1287
1288    # Get quantization stats. Ensures there is one stat per name if the stats
1289    # are specified.
1290    if self.quantized_input_stats:
1291      quantized_stats = []
1292      invalid_stats = []
1293      for name in self.get_input_arrays():
1294        if name in self.quantized_input_stats:
1295          quantized_stats.append(self.quantized_input_stats[name])
1296        else:
1297          invalid_stats.append(name)
1298
1299      if invalid_stats:
1300        raise ValueError("Quantization input stats are not available for input "
1301                         "tensors '{0}'.".format(",".join(invalid_stats)))
1302    else:
1303      quantized_stats = None
1304
1305    optimized_graph = self._graph_def
1306    if not self.saved_model_dir:
1307      # Disable grappler constant folding if there are training quant ops.
1308      if not quant_mode.contains_training_quant_op():
1309        try:
1310          # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
1311          # Grappler will also try to lower while loop into switch merge
1312          # representation which is undesired for Ophints, so we simply remove
1313          # those attributes to prevent Grappler from doing so.
1314          graph_def = _convert_to_constants.disable_lower_using_switch_merge(
1315              optimized_graph)
1316          # Run function inlining optimization to ensure any models generated
1317          # through the from_frozen_graph path have been inlined.
1318          optimized_graph = _run_graph_optimizations(
1319              graph_def,
1320              self._input_tensors,
1321              self._output_tensors,
1322              config=self._grappler_config(["function"]))
1323        except Exception:  # pylint: disable=broad-except
1324          optimized_graph = self._graph_def
1325
1326    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
1327
1328    converter_kwargs = self._get_base_converter_args()
1329    converter_kwargs.update(
1330        quant_mode.converter_flags(self.inference_type,
1331                                   self.inference_input_type))
1332    converter_kwargs.update({
1333        "output_format": self.output_format,
1334        "quantized_input_stats": quantized_stats,
1335        "default_ranges_stats": self.default_ranges_stats,
1336        "drop_control_dependency": self.drop_control_dependency,
1337        "reorder_across_fake_quant": self.reorder_across_fake_quant,
1338        "change_concat_input_ranges": self.change_concat_input_ranges,
1339        "dump_graphviz_dir": self.dump_graphviz_dir,
1340        "dump_graphviz_video": self.dump_graphviz_video,
1341        "conversion_summary_dir": self.conversion_summary_dir,
1342        "custom_opdefs": self._custom_opdefs,
1343    })
1344
1345    if not self.experimental_new_converter:
1346      logging.warning(
1347          "Please consider switching to the new converter by setting "
1348          "experimental_new_converter=True. "
1349          "The old converter (TOCO) is deprecated.")
1350    else:
1351      logging.info("Using experimental converter: If you encountered a problem "
1352                   "please file a bug. You can opt-out "
1353                   "by setting experimental_new_converter=False")
1354
1355    if not self.experimental_new_converter:
1356      calibrate_quantize, flags = quant_mode.quantizer_flags(
1357          self.inference_input_type, self.inference_output_type)
1358    else:
1359      calibrate_quantize, flags = quant_mode.quantizer_flags()
1360
1361    self._validate_quantized_input_stats(converter_kwargs, calibrate_quantize)
1362    self._validate_experimental_new_quantizer_flag()
1363
1364    # Converts model.
1365    if self._has_valid_tensors():
1366      result = _toco_convert_impl(
1367          input_data=optimized_graph,
1368          input_tensors=self._input_tensors,
1369          output_tensors=self._output_tensors,
1370          **converter_kwargs)
1371    else:
1372      result = _toco_convert_graph_def(
1373          input_data=optimized_graph,
1374          input_arrays_with_shape=self._input_arrays_with_shape,
1375          output_arrays=self._output_arrays,
1376          **converter_kwargs)
1377
1378    if calibrate_quantize:
1379      result = self._calibrate_quantize_model(result, **flags)
1380
1381    if self.experimental_new_converter or self.experimental_new_quantizer:
1382      flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
1383          self.inference_input_type, self.inference_output_type)
1384      if flags_modify_model_io_type:
1385        result = _modify_model_io_type(result, **flags_modify_model_io_type)
1386
1387    if self._sparsify_model():
1388      result = _mlir_sparsify(result)
1389
1390    return result
1391
1392  def get_input_arrays(self):
1393    """Returns a list of the names of the input tensors.
1394
1395    Returns:
1396      List of strings.
1397    """
1398    if self._has_valid_tensors():
1399      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
1400    else:
1401      return [name for name, _ in self._input_arrays_with_shape]
1402
1403  def _has_valid_tensors(self):
1404    """Checks if the input and output tensors have been initialized.
1405
1406    Returns:
1407      Bool.
1408    """
1409    return self._input_tensors is not None and self._output_tensors
1410
1411  def _set_batch_size(self, batch_size):
1412    """Sets the first dimension of the input tensor to `batch_size`.
1413
1414    Args:
1415      batch_size: Batch size for the model. Replaces the first dimension of an
1416        input size array if undefined. (default 1)
1417
1418    Raises:
1419      ValueError: input_tensor is not defined.
1420    """
1421    if not self._has_valid_tensors():
1422      raise ValueError("The batch size cannot be set for this model. Please "
1423                       "use input_shapes parameter.")
1424
1425    for tensor in self._input_tensors:
1426      shape = tensor.shape.as_list()
1427      if shape[0] is None:
1428        shape[0] = batch_size
1429        tensor.set_shape(shape)
1430
1431  def _is_unknown_shapes_allowed(self):
1432    # Ophint Converted nodes will need the shapes to be known.
1433    if _is_ophint_converted(self._graph_def):
1434      return False
1435
1436    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
1437      return False
1438
1439    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
1440    # the MLIR converter.
1441    if self.conversion_summary_dir:
1442      logging.warning(
1443          "`conversion_summary_dir` does not work with unknown shapes. "
1444          "Graphs with unknown shapes might be different than when this flag "
1445          "is disabled.")
1446      return False
1447    return True
1448
1449
1450class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
1451  """Converts the given SavedModel into TensorFlow Lite model.
1452
1453  Attributes:
1454      saved_model_dir: Directory of the SavedModel.
1455  """
1456
1457  def __init__(self,
1458               saved_model_dir,
1459               saved_model_tags,
1460               saved_model_exported_names,
1461               experimental_debug_info_func=None):
1462    """Constructor for TFLiteConverter.
1463
1464    Args:
1465      saved_model_dir: Directory of the SavedModel.
1466      saved_model_tags: Set of tags identifying the MetaGraphDef within the
1467        SavedModel to analyze. All tags in the tag set must be present. (default
1468        {tf.saved_model.SERVING}).
1469      saved_model_exported_names: Names to be exported when the saved model
1470        import path is on.
1471      experimental_debug_info_func: An experimental function to retrieve the
1472        graph debug info for a set of nodes from the `graph_def`.
1473
1474    Raises:
1475      ValueError: Invalid arguments.
1476    """
1477    super(TFLiteSavedModelConverter,
1478          self).__init__(experimental_debug_info_func)
1479    self.saved_model_dir = saved_model_dir
1480    self._saved_model_tags = saved_model_tags
1481    self._saved_model_exported_names = saved_model_exported_names
1482
1483    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1484
1485    if len(self._saved_model_exported_names) != 1:
1486      raise ValueError("Only support a single signature key.")
1487
1488    signature_key = self._saved_model_exported_names[0]
1489
1490    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
1491                                 self._saved_model_tags, signature_key)
1492    self._graph_def = result[0]
1493    self._input_tensors = result[1]
1494    self._output_tensors = result[2]
1495    self._parse_saved_model_args()
1496
1497
1498class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
1499  """Converts the given SavedModel into TensorFlow Lite model."""
1500
1501  def __init__(self,
1502               model_file,
1503               input_arrays=None,
1504               input_shapes=None,
1505               output_arrays=None,
1506               custom_objects=None):
1507    """Constructor for TFLiteConverter.
1508
1509    Args:
1510      model_file: Full filepath of HDF5 file containing the tf.keras model.
1511      input_arrays: List of input tensors to freeze graph with. Uses input
1512        arrays from SignatureDef when none are provided. (default None)
1513      input_shapes: Dict of strings representing input tensor names to list of
1514        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
1515        Automatically determined when input shapes is None (e.g., {"foo" :
1516          None}). (default None)
1517      output_arrays: List of output tensors to freeze graph with. Uses output
1518        arrays from SignatureDef when none are provided. (default None)
1519      custom_objects: Dict mapping names (strings) to custom classes or
1520        functions to be considered during model deserialization. (default None)
1521
1522    Raises:
1523      ValueError: Invalid arguments.
1524    """
1525    super(TFLiteKerasModelConverter,
1526          self).__init__(experimental_debug_info_func=None)
1527    # Handles Keras when Eager mode is enabled.
1528    if context.executing_eagerly():
1529      if input_arrays or output_arrays:
1530        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
1531                         "with Eager mode. If your model requires any of these "
1532                         "parameters, please use disable_eager_execution().")
1533
1534      keras_model = keras_deps.get_load_model_function()(
1535          model_file, custom_objects)
1536      function = _trace_model_call(keras_model)
1537      concrete_func = function.get_concrete_function()
1538
1539      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
1540          concrete_func, lower_control_flow=False)
1541      _set_tensor_shapes(frozen_func.inputs, input_shapes)
1542      self._keras_model = keras_model
1543      self._graph_def = frozen_func.graph.as_graph_def()
1544      self._input_tensors = frozen_func.inputs
1545      self._output_tensors = frozen_func.outputs
1546      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
1547      return
1548
1549    # Handles Keras when Eager mode is disabled.
1550    keras_deps.get_clear_session_function()()
1551    keras_model = keras_deps.get_load_model_function()(
1552        model_file, custom_objects)
1553    sess = keras_deps.get_get_session_function()()
1554
1555    # Get input and output tensors.
1556    if input_arrays:
1557      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
1558    else:
1559      input_tensors = keras_model.inputs
1560
1561    if output_arrays:
1562      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
1563    else:
1564      output_tensors = keras_model.outputs
1565    _set_tensor_shapes(input_tensors, input_shapes)
1566
1567    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
1568    self._keras_model = keras_model
1569    self._graph_def = graph_def
1570    self._input_tensors = input_tensors
1571    self._output_tensors = output_tensors
1572    self._debug_info_func = _build_debug_info_func(sess.graph)
1573
1574  def _convert_as_saved_model(self):
1575    """Converts a Keras model as a saved model.
1576
1577    Returns:
1578      The converted data in serialized format.
1579    """
1580    temp_dir = tempfile.mkdtemp()
1581    try:
1582      try:
1583        self._keras_model.save(temp_dir, save_format="tf")
1584      except Exception:  # pylint: disable=broad-except
1585        # When storing the given keras model to a saved model is failed, let's
1586        # use original keras model conversion pipeline.
1587        return None
1588      tag_set = set([_tag_constants.SERVING])
1589      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1590      result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
1591                                   signature_key)
1592
1593      self.saved_model_dir = temp_dir
1594      self._saved_model_tags = tag_set
1595      self._saved_model_exported_names = [signature_key]
1596      self._parse_saved_model_args()
1597      if self.saved_model_dir:
1598        self._graph_def = result[0]
1599        self._input_tensors = result[1]
1600        self._output_tensors = result[2]
1601        self._debug_info_func = _build_debug_info_func(result[3])
1602        return super(TFLiteKerasModelConverter, self).convert()
1603    finally:
1604      shutil.rmtree(temp_dir, True)
1605
1606  def convert(self):
1607    """Converts a Keras model based on instance variables.
1608
1609    Returns:
1610      The converted data in serialized format. Either a TFLite Flatbuffer or a
1611      Graphviz graph depending on value in `output_format`.
1612
1613    Raises:
1614      ValueError:
1615        Input shape is not specified.
1616        None value for dimension in input_tensor.
1617    """
1618    saved_model_convert_result = self._convert_as_saved_model()
1619    if saved_model_convert_result:
1620      return saved_model_convert_result
1621
1622    return super(TFLiteKerasModelConverter, self).convert()
1623
1624
1625class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
1626  """Converts the given frozen graph def into TensorFlow Lite model."""
1627
1628  def __init__(self,
1629               graph_def,
1630               input_tensors,
1631               output_tensors,
1632               input_arrays_with_shape=None,
1633               output_arrays=None,
1634               experimental_debug_info_func=None):
1635    """Constructor for TFLiteConverter.
1636
1637    Args:
1638      graph_def: Frozen TensorFlow GraphDef.
1639      input_tensors: List of input tensors. Type and shape are computed using
1640        `foo.shape` and `foo.dtype`.
1641      output_tensors: List of output tensors (only .name is used from this).
1642      input_arrays_with_shape: Tuple of strings representing input tensor names
1643        and list of integers representing input shapes
1644        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
1645          into TensorFlow and when `input_tensors` and `output_tensors` are
1646          None. (default None)
1647      output_arrays: List of output tensors to freeze graph with. Use only when
1648        graph cannot be loaded into TensorFlow and when `input_tensors` and
1649        `output_tensors` are None. (default None)
1650      experimental_debug_info_func: An experimental function to retrieve the
1651        graph debug info for a set of nodes from the `graph_def`.
1652
1653    Raises:
1654      ValueError: Invalid arguments.
1655    """
1656    super(TFLiteFrozenGraphConverter,
1657          self).__init__(experimental_debug_info_func)
1658    self._graph_def = graph_def
1659    self._input_tensors = input_tensors
1660    self._output_tensors = output_tensors
1661
1662    # Attributes are used by models that cannot be loaded into TensorFlow.
1663    if not self._has_valid_tensors():
1664      if not input_arrays_with_shape or not output_arrays:
1665        raise ValueError(
1666            "If input_tensors and output_tensors are None, both "
1667            "input_arrays_with_shape and output_arrays must be defined.")
1668      self._input_arrays_with_shape = input_arrays_with_shape
1669      self._output_arrays = output_arrays
1670
1671    if input_tensors is not None and input_arrays_with_shape is not None:
1672      logging.warning("input_arrays_with_shape will be ignored when both the "
1673                      "given input_tensors and input_arrays_with_shape are not "
1674                      "None.")
1675
1676    if output_tensors is not None and output_arrays is not None:
1677      logging.warning("output_arrays will be ignored when both the given "
1678                      "output_tensors and output_arrays are not None.")
1679
1680
1681@_tf_export(v1=["lite.TFLiteConverter"])
1682class TFLiteConverter(TFLiteFrozenGraphConverter):
1683  """Convert a TensorFlow model into `output_format`.
1684
1685  This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
1686  model into either a TFLite FlatBuffer or graph visualization.
1687
1688  Attributes:
1689    optimizations: Experimental flag, subject to change. Set of optimizations to
1690      apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
1691      set of values of type `tf.lite.Optimize`)
1692    representative_dataset: A generator function used for integer quantization
1693      where each generated sample has the same order, type and shape as the
1694      inputs to the model. Usually, this is a small subset of a few hundred
1695      samples randomly chosen, in no particular order, from the training or
1696      evaluation dataset. This is an optional attribute, but required for full
1697      integer quantization, i.e, if `tf.int8` is the only supported type in
1698      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
1699      (default None)
1700    target_spec: Experimental flag, subject to change. Specifications of target
1701      device, including supported ops set, supported types and a set of user's
1702      defined TensorFlow operators required in the TensorFlow Lite runtime.
1703      Refer to `tf.lite.TargetSpec`.
1704    inference_type: Data type of numeric arrays, excluding the input layer.
1705      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
1706    inference_input_type: Data type of the numeric arrays in the input layer. If
1707      `inference_input_type` is in {tf.int8, tf.uint8}, then
1708      `quantized_input_stats` must be provided. (default is the value assigned
1709      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
1710    inference_output_type: Data type of the numeric arrays in the output layer.
1711      (default is the value assigned to `inference_type`, must be in
1712      {tf.float32, tf.int8, tf.uint8})
1713    quantized_input_stats: Map of input tensor names to a tuple of floats
1714      representing the mean and standard deviation of the training data.
1715      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
1716        or tf.uint8. (default None)
1717    default_ranges_stats: Tuple of integers (min, max) representing range values
1718      for all numeric arrays without a specified range. Intended for
1719      experimenting with quantization via "dummy quantization". (default None)
1720    allow_custom_ops: Boolean indicating whether to allow custom operations.
1721      When False any unknown operation is an error. When True, custom ops are
1722      created for any op that is unknown. The developer will need to provide
1723      these to the TensorFlow Lite runtime with a custom resolver. (default
1724      False)
1725    drop_control_dependency: Boolean indicating whether to drop control
1726      dependencies silently. This is due to TFLite not supporting control
1727      dependencies. (default True)
1728    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
1729      nodes in unexpected locations. Used when the location of the FakeQuant
1730      nodes is preventing graph transformations necessary to convert the graph.
1731      Results in a graph that differs from the quantized training graph,
1732      potentially causing differing arithmetic behavior. (default False)
1733    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
1734      inputs and outputs of the concat operator for quantized models. Changes
1735      the ranges of concat operator overlap when true. (default False)
1736    output_format: Output file format. (default
1737      tf.compat.v1.lite.constants.TFLITE, must be in
1738      {tf.compat.v1.lite.constants.TFLITE,
1739      tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
1740    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
1741      stages of processing GraphViz .dot files. Preferred over
1742      `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
1743      the requirements of the output file. (default None)
1744    dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
1745      files after every graph transformation. Requires the `dump_graphviz_dir`
1746      flag to be specified. (default False)
1747    conversion_summary_dir: Full path of the directory to store conversion logs.
1748      (default None)
1749    target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
1750    post_training_quantize: Deprecated. Please use `optimizations` instead and
1751      set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
1752    experimental_new_converter: Experimental flag, subject to change. Enables
1753      MLIR-based conversion instead of TOCO conversion. (default True)
1754    experimental_new_quantizer: Experimental flag, subject to change. Enables
1755      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
1756      (default False)
1757
1758  Example usage:
1759
1760    ```python
1761    # Converting a GraphDef from session.
1762    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
1763      sess, in_tensors, out_tensors)
1764    tflite_model = converter.convert()
1765    open("converted_model.tflite", "wb").write(tflite_model)
1766
1767    # Converting a GraphDef from file.
1768    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
1769      graph_def_file, input_arrays, output_arrays)
1770    tflite_model = converter.convert()
1771    open("converted_model.tflite", "wb").write(tflite_model)
1772
1773    # Converting a SavedModel.
1774    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
1775        saved_model_dir)
1776    tflite_model = converter.convert()
1777    open("converted_model.tflite", "wb").write(tflite_model)
1778
1779    # Converting a tf.keras model.
1780    converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
1781        keras_model)
1782    tflite_model = converter.convert()
1783    open("converted_model.tflite", "wb").write(tflite_model)
1784    ```
1785  """
1786
1787  # pylint: disable=useless-super-delegation
1788  def __init__(self,
1789               graph_def,
1790               input_tensors,
1791               output_tensors,
1792               input_arrays_with_shape=None,
1793               output_arrays=None,
1794               experimental_debug_info_func=None):
1795    """Constructor for TFLiteConverter.
1796
1797    Args:
1798      graph_def: Frozen TensorFlow GraphDef.
1799      input_tensors: List of input tensors. Type and shape are computed using
1800        `foo.shape` and `foo.dtype`.
1801      output_tensors: List of output tensors (only .name is used from this).
1802      input_arrays_with_shape: Tuple of strings representing input tensor names
1803        and list of integers representing input shapes
1804        (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
1805          into TensorFlow and when `input_tensors` and `output_tensors` are
1806          None. (default None)
1807      output_arrays: List of output tensors to freeze graph with. Use only when
1808        graph cannot be loaded into TensorFlow and when `input_tensors` and
1809        `output_tensors` are None. (default None)
1810      experimental_debug_info_func: An experimental function to retrieve the
1811        graph debug info for a set of nodes from the `graph_def`.
1812
1813    Raises:
1814      ValueError: Invalid arguments.
1815    """
1816    super(TFLiteConverter,
1817          self).__init__(graph_def, input_tensors, output_tensors,
1818                         input_arrays_with_shape, output_arrays,
1819                         experimental_debug_info_func)
1820
1821  @classmethod
1822  def from_session(cls, sess, input_tensors, output_tensors):
1823    """Creates a TFLiteConverter class from a TensorFlow Session.
1824
1825    Args:
1826      sess: TensorFlow Session.
1827      input_tensors: List of input tensors. Type and shape are computed using
1828        `foo.shape` and `foo.dtype`.
1829      output_tensors: List of output tensors (only .name is used from this).
1830
1831    Returns:
1832      TFLiteConverter class.
1833    """
1834    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
1835    return cls(
1836        graph_def,
1837        input_tensors,
1838        output_tensors,
1839        experimental_debug_info_func=_build_debug_info_func(sess.graph))
1840
1841  @classmethod
1842  def from_frozen_graph(cls,
1843                        graph_def_file,
1844                        input_arrays,
1845                        output_arrays,
1846                        input_shapes=None):
1847    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
1848
1849    Args:
1850      graph_def_file: Full filepath of file containing frozen GraphDef.
1851      input_arrays: List of input tensors to freeze graph with.
1852      output_arrays: List of output tensors to freeze graph with.
1853      input_shapes: Dict of strings representing input tensor names to list of
1854        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
1855        Automatically determined when input shapes is None (e.g., {"foo" :
1856          None}). (default None)
1857
1858    Returns:
1859      TFLiteConverter class.
1860
1861    Raises:
1862      IOError:
1863        File not found.
1864        Unable to parse input file.
1865      ValueError:
1866        The graph is not frozen.
1867        input_arrays or output_arrays contains an invalid tensor name.
1868        input_shapes is not correctly defined when required
1869    """
1870    with _ops.Graph().as_default():
1871      with _session.Session() as sess:
1872        # Read GraphDef from file.
1873        if not _file_io.file_exists(graph_def_file):
1874          raise IOError("File '{0}' does not exist.".format(graph_def_file))
1875        with _file_io.FileIO(graph_def_file, "rb") as f:
1876          file_content = f.read()
1877
1878        try:
1879          graph_def = _graph_pb2.GraphDef()
1880          graph_def.ParseFromString(file_content)
1881        except (_text_format.ParseError, DecodeError):
1882          try:
1883            print("Ignore 'tcmalloc: large alloc' warnings.")
1884
1885            if not isinstance(file_content, str):
1886              if PY2:
1887                file_content = six.ensure_binary(file_content, "utf-8")
1888              else:
1889                file_content = six.ensure_text(file_content, "utf-8")
1890            graph_def = _graph_pb2.GraphDef()
1891            _text_format.Merge(file_content, graph_def)
1892          except (_text_format.ParseError, DecodeError):
1893            raise IOError(
1894                "Unable to parse input file '{}'.".format(graph_def_file))
1895
1896        # Handles models with custom TFLite ops that cannot be resolved in
1897        # TensorFlow.
1898        load_model_in_session = True
1899        try:
1900          _import_graph_def(graph_def, name="")
1901        except _NotFoundError:
1902          load_model_in_session = False
1903
1904        if load_model_in_session:
1905          # Check if graph is frozen.
1906          if not _is_frozen_graph(sess):
1907            raise ValueError("Please freeze the graph using freeze_graph.py.")
1908
1909          # Get input and output tensors.
1910          input_tensors = _get_tensors_from_tensor_names(
1911              sess.graph, input_arrays)
1912          output_tensors = _get_tensors_from_tensor_names(
1913              sess.graph, output_arrays)
1914          _set_tensor_shapes(input_tensors, input_shapes)
1915
1916          return cls(sess.graph_def, input_tensors, output_tensors)
1917        else:
1918          if not input_shapes:
1919            raise ValueError("input_shapes must be defined for this model.")
1920          if set(input_arrays) != set(input_shapes.keys()):
1921            raise ValueError("input_shapes must contain a value for each item "
1922                             "in input_array.")
1923
1924          input_arrays_with_shape = [
1925              (name, input_shapes[name]) for name in input_arrays
1926          ]
1927          return cls(
1928              graph_def,
1929              input_tensors=None,
1930              output_tensors=None,
1931              input_arrays_with_shape=input_arrays_with_shape,
1932              output_arrays=output_arrays)
1933
1934  @classmethod
1935  def from_saved_model(cls,
1936                       saved_model_dir,
1937                       input_arrays=None,
1938                       input_shapes=None,
1939                       output_arrays=None,
1940                       tag_set=None,
1941                       signature_key=None):
1942    """Creates a TFLiteConverter class from a SavedModel.
1943
1944    Args:
1945      saved_model_dir: SavedModel directory to convert.
1946      input_arrays: List of input tensors to freeze graph with. Uses input
1947        arrays from SignatureDef when none are provided. (default None)
1948      input_shapes: Dict of strings representing input tensor names to list of
1949        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
1950        Automatically determined when input shapes is None (e.g., {"foo" :
1951          None}). (default None)
1952      output_arrays: List of output tensors to freeze graph with. Uses output
1953        arrays from SignatureDef when none are provided. (default None)
1954      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
1955        analyze. All tags in the tag set must be present. (default
1956        {tf.saved_model.SERVING})
1957      signature_key: Key identifying SignatureDef containing inputs and outputs.
1958        (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
1959
1960    Returns:
1961      TFLiteConverter class.
1962    """
1963    if tag_set is None:
1964      tag_set = set([_tag_constants.SERVING])
1965    if signature_key is None:
1966      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1967
1968    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
1969                                                      [signature_key])
1970    if saved_model_converter.saved_model_dir:
1971      return saved_model_converter
1972
1973    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
1974                                 output_arrays, tag_set, signature_key)
1975
1976    return cls(
1977        graph_def=result[0],
1978        input_tensors=result[1],
1979        output_tensors=result[2],
1980        experimental_debug_info_func=_build_debug_info_func(result[3]))
1981
1982  @classmethod
1983  def from_keras_model_file(cls,
1984                            model_file,
1985                            input_arrays=None,
1986                            input_shapes=None,
1987                            output_arrays=None,
1988                            custom_objects=None):
1989    """Creates a TFLiteConverter class from a tf.keras model file.
1990
1991    Args:
1992      model_file: Full filepath of HDF5 file containing the tf.keras model.
1993      input_arrays: List of input tensors to freeze graph with. Uses input
1994        arrays from SignatureDef when none are provided. (default None)
1995      input_shapes: Dict of strings representing input tensor names to list of
1996        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
1997        Automatically determined when input shapes is None (e.g., {"foo" :
1998          None}). (default None)
1999      output_arrays: List of output tensors to freeze graph with. Uses output
2000        arrays from SignatureDef when none are provided. (default None)
2001      custom_objects: Dict mapping names (strings) to custom classes or
2002        functions to be considered during model deserialization. (default None)
2003
2004    Returns:
2005      TFLiteConverter class.
2006    """
2007    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
2008                                     output_arrays, custom_objects)
2009
2010  # pylint: disable=useless-super-delegation
2011  def convert(self):
2012    """Converts a TensorFlow GraphDef based on instance variables.
2013
2014    Returns:
2015      The converted data in serialized format. Either a TFLite Flatbuffer or a
2016      Graphviz graph depending on value in `output_format`.
2017
2018    Raises:
2019      ValueError:
2020        Input shape is not specified.
2021        None value for dimension in input_tensor.
2022    """
2023    return super(TFLiteConverter, self).convert()
2024
2025
2026@_tf_export(v1=["lite.TocoConverter"])
2027class TocoConverter(object):
2028  """Convert a TensorFlow model into `output_format` using TOCO.
2029
2030  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
2031  """
2032
2033  @classmethod
2034  @_deprecation.deprecated(None,
2035                           "Use `lite.TFLiteConverter.from_session` instead.")
2036  def from_session(cls, sess, input_tensors, output_tensors):
2037    """Creates a TocoConverter class from a TensorFlow Session."""
2038    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
2039
2040  @classmethod
2041  @_deprecation.deprecated(
2042      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
2043  def from_frozen_graph(cls,
2044                        graph_def_file,
2045                        input_arrays,
2046                        output_arrays,
2047                        input_shapes=None):
2048    """Creates a TocoConverter class from a file containing a frozen graph."""
2049    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
2050                                             output_arrays, input_shapes)
2051
2052  @classmethod
2053  @_deprecation.deprecated(
2054      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
2055  def from_saved_model(cls,
2056                       saved_model_dir,
2057                       input_arrays=None,
2058                       input_shapes=None,
2059                       output_arrays=None,
2060                       tag_set=None,
2061                       signature_key=None):
2062    """Creates a TocoConverter class from a SavedModel."""
2063    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
2064                                            input_shapes, output_arrays,
2065                                            tag_set, signature_key)
2066
2067  @classmethod
2068  @_deprecation.deprecated(
2069      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
2070  def from_keras_model_file(cls,
2071                            model_file,
2072                            input_arrays=None,
2073                            input_shapes=None,
2074                            output_arrays=None):
2075    """Creates a TocoConverter class from a tf.keras model file."""
2076    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
2077                                                 input_shapes, output_arrays)
2078