1# Lint as: python2, python3 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""TensorFlow Lite tooling helper functionality.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import enum 23import shutil 24import tempfile 25import warnings 26 27from absl import logging 28import six 29from six import PY2 30 31from google.protobuf import text_format as _text_format 32from google.protobuf.message import DecodeError 33from tensorflow.core.framework import graph_pb2 as _graph_pb2 34from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import 35from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import 36from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import 37from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import 38from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import 39from tensorflow.lite.python import lite_constants as constants 40from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import 41from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model 42from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import 43from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize 44from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify 45from tensorflow.lite.python.convert import OpsSet 46from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import 47from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def 48from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl 49from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import 50from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model 51from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import 52from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import 53from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import 54from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted 55from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import 56from tensorflow.lite.python.optimize import calibrator as _calibrator 57from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func 58from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func 59from tensorflow.lite.python.util import freeze_graph as _freeze_graph 60from tensorflow.lite.python.util import get_debug_info as _get_debug_info 61from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config 62from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name 63from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names 64from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name 65from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph 66from tensorflow.lite.python.util import model_input_signature as _model_input_signature 67from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type 68from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations 69from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes 70from tensorflow.lite.python.util import trace_model_call as _trace_model_call 71from tensorflow.python.client import session as _session 72from tensorflow.python.eager import context 73from tensorflow.python.eager import def_function as _def_function 74from tensorflow.python.eager import function as _function 75from tensorflow.python.framework import convert_to_constants as _convert_to_constants 76from tensorflow.python.framework import dtypes as _dtypes 77from tensorflow.python.framework import ops as _ops 78from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError 79from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 80from tensorflow.python.lib.io import file_io as _file_io 81from tensorflow.python.saved_model import loader_impl as _loader_impl 82from tensorflow.python.saved_model import signature_constants as _signature_constants 83from tensorflow.python.saved_model import tag_constants as _tag_constants 84from tensorflow.python.saved_model.load import load as _load 85from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info 86from tensorflow.python.util import deprecation as _deprecation 87from tensorflow.python.util import keras_deps 88from tensorflow.python.util.tf_export import tf_export as _tf_export 89 90 91@_tf_export("lite.Optimize") 92class Optimize(enum.Enum): 93 """Enum defining the optimizations to apply when generating a tflite model. 94 95 DEFAULT 96 Default optimization strategy that quantizes model weights. Enhanced 97 optimizations are gained by providing a representative dataset that 98 quantizes biases and activations as well. 99 Converter will do its best to reduce size and latency, while minimizing 100 the loss in accuracy. 101 102 OPTIMIZE_FOR_SIZE 103 Deprecated. Does the same as DEFAULT. 104 105 OPTIMIZE_FOR_LATENCY 106 Deprecated. Does the same as DEFAULT. 107 108 EXPERIMENTAL_SPARSITY 109 Experimental flag, subject to change. 110 111 Enable optimization by taking advantage of the sparse model weights 112 trained with pruning. 113 114 The converter will inspect the sparsity pattern of the model weights and 115 do its best to improve size and latency. 116 The flag can be used alone to optimize float32 models with sparse weights. 117 It can also be used together with the DEFAULT optimization mode to 118 optimize quantized models with sparse weights. 119 """ 120 121 # Default optimization strategy that quantizes model weights. Enhanced 122 # optimizations are gained by providing a representative dataset that 123 # quantizes biases and activations as well. 124 # Converter will do its best to reduce size and latency, while minimizing 125 # the loss in accuracy. 126 DEFAULT = "DEFAULT" 127 128 # Deprecated. Does the same as DEFAULT. 129 OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE" 130 131 # Deprecated. Does the same as DEFAULT. 132 OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY" 133 134 # Experimental flag, subject to change. 135 # Enable optimization by taking advantage of the sparse model weights trained 136 # with pruning. 137 # 138 # The converter will inspect the sparsity pattern of the model weights and do 139 # its best to improve size and latency. 140 # The flag can be used alone to optimize float32 models with sparse weights. 141 # It can also be used together with the DEFAULT optimization mode to optimize 142 # quantized models with sparse weights. 143 # TODO(b/161560631): Add log message when this optimization is applied. 144 EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY" 145 146 def __str__(self): 147 return str(self.value) 148 149 150@_tf_export("lite.RepresentativeDataset") 151class RepresentativeDataset(object): 152 """Representative dataset used to optimize the model. 153 154 This is a generator function that provides a small dataset to calibrate or 155 estimate the range, i.e, (min, max) of all floating-point arrays in the model 156 (such as model input, activation outputs of intermediate layers, and model 157 output) for quantization. Usually, this is a small subset of a few hundred 158 samples randomly chosen, in no particular order, from the training or 159 evaluation dataset. 160 """ 161 162 def __init__(self, input_gen): 163 """Creates a representative dataset. 164 165 Args: 166 input_gen: A generator function that generates input samples for the 167 model and has the same order, type and shape as the inputs to the model. 168 Usually, this is a small subset of a few hundred samples randomly 169 chosen, in no particular order, from the training or evaluation dataset. 170 """ 171 self.input_gen = input_gen 172 173 174@_tf_export("lite.TargetSpec") 175class TargetSpec(object): 176 """Specification of target device used to optimize the model. 177 178 Attributes: 179 supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet` 180 options, where each option represents a set of operators supported by the 181 target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})) 182 supported_types: Set of `tf.dtypes.DType` data types supported on the target 183 device. If initialized, optimization might be driven by the smallest type 184 in this set. (default set()) 185 experimental_select_user_tf_ops: Experimental flag, subject to change. Set 186 of user's TensorFlow operators' names that are required in the TensorFlow 187 Lite runtime. These ops will be exported as select TensorFlow ops in the 188 model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is 189 an advanced feature that should only be used if the client is using TF ops 190 that may not be linked in by default with the TF ops that are provided 191 when using the SELECT_TF_OPS path. The client is responsible for linking 192 these ops into the target runtime. 193 """ 194 195 def __init__(self, 196 supported_ops=None, 197 supported_types=None, 198 experimental_select_user_tf_ops=None): 199 if supported_ops is None: 200 supported_ops = {OpsSet.TFLITE_BUILTINS} 201 self.supported_ops = supported_ops 202 if supported_types is None: 203 supported_types = set() 204 self.supported_types = supported_types 205 if experimental_select_user_tf_ops is None: 206 self.experimental_select_user_tf_ops = set() 207 208 209class QuantizationMode(object): 210 """QuantizationMode determines the quantization type from user options.""" 211 212 def __init__(self, optimizations, target_spec, representative_dataset, 213 graph_def): 214 self._optimizations = optimizations 215 self._target_spec = target_spec 216 self._representative_dataset = representative_dataset 217 self._graph_def = graph_def 218 219 self._validate_int8_required() 220 221 # TODO(b/162537905): Refactor the following quantization functions - 222 # re-organize and refactor for better readability. 223 def post_training_int8_no_float(self): 224 return (self._any_optimization_enabled() and 225 self._is_int8_target_required() and 226 not self._is_int16x8_target_required() and 227 not self._is_allow_float() and 228 self._representative_dataset is not None) 229 230 def post_training_int8_allow_float(self): 231 return (self._any_optimization_enabled() and 232 not self._is_int16x8_target_required() and 233 self._representative_dataset is not None and 234 self._smallest_supported_type() == _dtypes.int8) 235 236 def is_post_training_integer_quantize_8(self): 237 return (self.post_training_int8_no_float() or 238 self.post_training_int8_allow_float()) 239 240 def is_post_training_integer_quantize_16x8(self): 241 return (self.post_training_int16x8_no_float() or 242 self.post_training_int16x8_allow_float()) 243 244 def is_integer_quantize(self): 245 return (self.is_post_training_integer_quantize_8() or 246 self.is_post_training_integer_quantize_16x8() or 247 self.is_training_time_int8_allow_float()) 248 249 def is_training_time_int8_allow_float(self): 250 return (self._any_optimization_enabled() and 251 self.contains_training_quant_op()) 252 253 def post_training_int16x8_no_float(self): 254 return (self._any_optimization_enabled() and 255 not self._is_int8_target_required() and 256 self._is_int16x8_target_required() and 257 not self._is_allow_float() and 258 self._representative_dataset is not None) 259 260 def post_training_int16x8_allow_float(self): 261 return (self._any_optimization_enabled() and 262 self._is_int16x8_target_required() and 263 self._is_allow_float()) 264 265 def post_training_dynamic_range_int8(self): 266 # Post-training dynamic range quantization is only enabled if post-training 267 # int8 quantization and training time quantization was not done. 268 return (self._any_optimization_enabled() and 269 self._representative_dataset is None and 270 not self.contains_training_quant_op() and 271 self._smallest_supported_type() == _dtypes.int8) 272 273 def post_training_fp16(self): 274 return (self._any_optimization_enabled() and 275 self._smallest_supported_type() == _dtypes.float16) 276 277 def fp32_execution(self): 278 """If none of the above are true.""" 279 return not (self.is_integer_quantize() or 280 self.post_training_dynamic_range_int8() or 281 self.post_training_fp16()) 282 283 def activations_type(self): 284 return _dtypes.int16 if self._is_int16x8_target_required() \ 285 else _dtypes.int8 286 287 def converter_flags(self, inference_ty=None, inference_input_ty=None): 288 """Flags to the converter.""" 289 290 if self.is_integer_quantize(): 291 return { 292 "inference_type": inference_ty if inference_ty else \ 293 self.activations_type(), 294 "inference_input_type": _dtypes.float32, 295 "post_training_quantize": False, # disable dynamic range quantization 296 "quantize_to_float16": False # disable float16 quantization 297 } 298 elif self.post_training_dynamic_range_int8(): 299 return { 300 "inference_type": _dtypes.float32, 301 "inference_input_type": _dtypes.float32, 302 "post_training_quantize": True, # enable dynamic range quantization 303 "quantize_to_float16": False # disable float16 quantization 304 } 305 elif self.post_training_fp16(): 306 return { 307 "inference_type": _dtypes.float32, 308 "inference_input_type": _dtypes.float32, 309 "post_training_quantize": True, 310 "quantize_to_float16": True # enable float16 quantization 311 } 312 else: 313 # Note this might still trigger (uint8) quantization to be compatible with 314 # TOCO. 315 return { 316 "inference_type": inference_ty if inference_ty else _dtypes.float32, 317 "inference_input_type": inference_input_ty, 318 "post_training_quantize": False, # enable dynamic range quantization 319 "quantize_to_float16": False # disable float16 quantization 320 } 321 322 def quantizer_flags(self, input_ty=None, output_ty=None): 323 """Default flags to the TFMOT quantizer.""" 324 325 inference_input_type = input_ty if input_ty else _dtypes.float32 326 inference_output_type = output_ty if output_ty else _dtypes.float32 327 328 if self.post_training_int8_no_float() \ 329 or self.post_training_int16x8_no_float(): 330 return True, { 331 "inference_input_type": inference_input_type, 332 "inference_output_type": inference_output_type, 333 "activations_type": self.activations_type(), 334 "allow_float": False 335 } 336 elif self.post_training_int8_allow_float() \ 337 or self.post_training_int16x8_allow_float(): 338 return True, { 339 "inference_input_type": inference_input_type, 340 "inference_output_type": inference_output_type, 341 "activations_type": self.activations_type(), 342 "allow_float": True 343 } 344 else: 345 return False, None 346 347 def flags_modify_model_io_type(self, input_ty=None, output_ty=None): 348 """Flags for modifying the input and output type of a tflite model.""" 349 350 if self.is_integer_quantize(): 351 return { 352 "inference_input_type": input_ty if input_ty else _dtypes.float32, 353 "inference_output_type": output_ty if output_ty else _dtypes.float32, 354 } 355 else: 356 return None 357 358 # Below are helpers for the above functions. 359 360 def _validate_int8_required(self): 361 """Int8 mode requires certain parameters to exist and be compatible.""" 362 if not self._is_int8_target_required(): 363 return 364 365 if self._target_spec.supported_types and (self._smallest_supported_type() != 366 _dtypes.int8): 367 raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported " 368 "type to be INT8.") 369 370 if self._representative_dataset: 371 if not isinstance(self._representative_dataset, RepresentativeDataset): 372 self._representative_dataset = RepresentativeDataset( 373 self._representative_dataset) 374 if self._representative_dataset.input_gen is None: 375 raise ValueError( 376 "Provide an input generator for representative_dataset") 377 else: 378 # TODO(b/150661651): Relax this check for QAT. 379 raise ValueError("representative_dataset is required when specifying " 380 "TFLITE_BUILTINS_INT8 or INT8 supported types.") 381 382 def _is_int8_target_required(self): 383 return (OpsSet.TFLITE_BUILTINS_INT8 in set( 384 self._target_spec.supported_ops)) or (set( 385 self._target_spec.supported_types) == set([_dtypes.int8])) 386 387 def _is_int16x8_target_required(self): 388 return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 389 in set(self._target_spec.supported_ops)) 390 391 def _is_allow_float(self): 392 return (OpsSet.TFLITE_BUILTINS in set( 393 self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set( 394 self._target_spec.supported_ops)) 395 396 def _any_optimization_enabled(self): 397 return bool( 398 set(self._optimizations).intersection([ 399 Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE, 400 Optimize.DEFAULT 401 ])) 402 403 def _smallest_supported_type(self): 404 if self._target_spec.supported_types: 405 return min(self._target_spec.supported_types, key=lambda x: x.size) 406 else: 407 # The default smallest supported type is INT8. 408 return _dtypes.int8 409 410 def contains_training_quant_op(self): 411 """Checks if the graph contains any training-time quantization ops.""" 412 training_quant_ops = frozenset({ 413 "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel", 414 "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsPerChannel", 415 "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3" 416 }) 417 418 for node_def in self._graph_def.node: 419 if node_def.op in training_quant_ops: 420 return True 421 for function in self._graph_def.library.function: 422 for node_def in function.node_def: 423 if node_def.op in training_quant_ops: 424 return True 425 return False 426 427 428class TFLiteConverterBase(object): 429 """Converter subclass to share functionality between V1 and V2 converters.""" 430 431 def __init__(self): 432 self.optimizations = set() 433 self.representative_dataset = None 434 self.target_spec = TargetSpec() 435 self.allow_custom_ops = False 436 self.experimental_new_converter = True 437 self.experimental_new_quantizer = False 438 self._experimental_new_quantizer = None 439 self._experimental_calibrate_only = False 440 self._experimental_sparsify_model = False 441 self._debug_info = None # contains the stack traces of all the original 442 # nodes in the `GraphDef` to the converter. 443 self.saved_model_dir = None 444 self._saved_model_tags = None 445 self._saved_model_version = 0 446 self._saved_model_exported_names = [] 447 448 def _grappler_config(self, optimizers=None): 449 """Creates a tf.compat.v1.ConfigProto for configuring Grappler. 450 451 Args: 452 optimizers: List of strings that represents the list of optimizers. 453 454 Returns: 455 tf.ConfigProto. 456 """ 457 if not optimizers: 458 optimizers = [] 459 # MLIR converter will take care of constant folding instead of grappler. 460 if not self.experimental_new_converter: 461 optimizers.append("constfold") 462 463 is_only_flex_enabled = ( 464 set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops)) 465 if is_only_flex_enabled: 466 # The layout optimizer turns NHCW to NCHW. This provides performance 467 # optimizations when Flex mode is enabled. However, this is not compatible 468 # with builtin ops. 469 optimizers.append("layout") 470 return _get_grappler_config(optimizers) 471 472 def _calibrate_quantize_model(self, result, inference_input_type, 473 inference_output_type, activations_type, 474 allow_float): 475 """Calibrate and quantize the model.""" 476 if not isinstance(self.representative_dataset, RepresentativeDataset): 477 self.representative_dataset = RepresentativeDataset( 478 self.representative_dataset) 479 480 # Add intermediate tensors to the model if needed. 481 result = _calibrator.add_intermediate_tensors(result) 482 calibrate_quantize = _calibrator.Calibrator(result) 483 if self._experimental_calibrate_only or self.experimental_new_quantizer: 484 calibrated = calibrate_quantize.calibrate( 485 self.representative_dataset.input_gen) 486 487 if self._experimental_calibrate_only: 488 return calibrated 489 elif self.experimental_new_quantizer and ( 490 activations_type != _dtypes.int16): 491 # TODO(b/175659372): remove the activations_type restriction and enable 492 # it for all the activation types. 493 return _mlir_quantize(calibrated) 494 else: 495 return calibrate_quantize.calibrate_and_quantize( 496 self.representative_dataset.input_gen, inference_input_type, 497 inference_output_type, allow_float, activations_type) 498 499 def _is_unknown_shapes_allowed(self): 500 # Unknown dimensions are only allowed with the new converter. 501 return self.experimental_new_converter 502 503 def _get_base_converter_args(self): 504 """Returns the base converter args. 505 506 Returns: 507 {key str: val} 508 """ 509 args = { 510 "input_format": constants.TENSORFLOW_GRAPHDEF, 511 "allow_custom_ops": self.allow_custom_ops, 512 "debug_info": self._debug_info, 513 "target_ops": self.target_spec.supported_ops, 514 "enable_mlir_converter": self.experimental_new_converter, 515 "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops, 516 } 517 518 if self.saved_model_dir: 519 args.update({ 520 "saved_model_dir": self.saved_model_dir, 521 "saved_model_version": self._saved_model_version, 522 "saved_model_tags": self._saved_model_tags, 523 "saved_model_exported_names": self._saved_model_exported_names, 524 }) 525 526 return args 527 528 def _contains_function_with_implements_attr(self, saved_model_proto): 529 meta_graph = saved_model_proto.meta_graphs[0] 530 for function in meta_graph.graph_def.library.function: 531 if function.attr.get("_implements", None) or function.attr.get( 532 "api_implements", None): 533 return True 534 return False 535 536 def _parse_saved_model_args(self, always_enable_saved_model_import=False): 537 """Parses SavedModel arguments from the given Keras/RNN SavedModel. 538 539 Args: 540 always_enable_saved_model_import: Bool. When the value is true, it enables 541 MLIR saved model import path regardless of checking the conditions. 542 """ 543 if not self.experimental_new_converter: 544 self.saved_model_dir = None 545 return 546 if self.saved_model_dir: 547 try: 548 saved_model_proto, _ = ( 549 _parse_saved_model_with_debug_info(self.saved_model_dir)) 550 except OSError: 551 # If it fails to read the given saved model, it will fall back to the 552 # frozen graph def path. 553 self.saved_model_dir = None 554 return 555 if (not always_enable_saved_model_import and 556 not self._contains_function_with_implements_attr(saved_model_proto)): 557 self.saved_model_dir = None 558 return 559 560 if not self._saved_model_exported_names: 561 self._saved_model_exported_names = [] 562 self._saved_model_version = saved_model_proto.saved_model_schema_version 563 if self._saved_model_version == 0: 564 self.saved_model_dir = None 565 logging.warning("SavedModel schema version is zero.") 566 return 567 if self._saved_model_version not in [1, 2]: 568 raise ValueError("SavedModel file format({0}) is not supported".format( 569 self._saved_model_version)) 570 571 def _sparsify_model(self): 572 return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations 573 574 def _validate_experimental_new_quantizer_flag(self): 575 if self._experimental_new_quantizer is not None: 576 raise ValueError("Please use 'experimental_new_quantizer' instead.") 577 578 579class TFLiteConverterBaseV2(TFLiteConverterBase): 580 """Converter subclass to share functionality between V2 converters.""" 581 582 def __init__(self): 583 """Constructor for TFLiteConverter.""" 584 super(TFLiteConverterBaseV2, self).__init__() 585 self.inference_input_type = _dtypes.float32 586 self.inference_output_type = _dtypes.float32 587 588 def _validate_inference_input_output_types(self, quant_mode): 589 """Validate inference_input_type and inference_output_type flags.""" 590 default_types = [_dtypes.float32] 591 # We support integer input/output for integer quantized models only. 592 if quant_mode.is_integer_quantize(): 593 if quant_mode.is_post_training_integer_quantize_16x8(): 594 all_types = default_types + [_dtypes.int16] 595 else: 596 all_types = default_types + [_dtypes.int8, _dtypes.uint8] 597 if self.inference_input_type not in all_types or \ 598 self.inference_output_type not in all_types: 599 all_types_names = ["tf." + t.name for t in all_types] 600 raise ValueError("The inference_input_type and inference_output_type " 601 "must be in {}.".format(all_types_names)) 602 elif self.inference_input_type not in default_types or \ 603 self.inference_output_type not in default_types: 604 raise ValueError("The inference_input_type and inference_output_type " 605 "must be tf.float32.") 606 607 def convert(self, graph_def, input_tensors, output_tensors): 608 """Converts a TensorFlow GraphDef based on instance variables. 609 610 Args: 611 graph_def: Frozen TensorFlow GraphDef. 612 input_tensors: List of input tensors. Type and shape are computed using 613 `foo.shape` and `foo.dtype`. 614 output_tensors: List of output tensors (only .name is used from this). 615 616 Returns: 617 The converted data in serialized format. 618 619 Raises: 620 ValueError: 621 No concrete functions is specified. 622 Multiple concrete functions are specified. 623 Input shape is not specified. 624 Invalid quantization parameters. 625 """ 626 quant_mode = QuantizationMode(self.optimizations, self.target_spec, 627 self.representative_dataset, graph_def) 628 629 self._validate_inference_input_output_types(quant_mode) 630 self._validate_experimental_new_quantizer_flag() 631 632 if not self._is_unknown_shapes_allowed(): 633 # Checks dimensions in input tensor. 634 for tensor in input_tensors: 635 # Note that shape_list might be empty for scalar shapes. 636 shape_list = tensor.shape.as_list() 637 if None in shape_list[1:]: 638 raise ValueError( 639 "None is only supported in the 1st dimension. Tensor '{0}' has " 640 "invalid shape '{1}'.".format( 641 _get_tensor_name(tensor), shape_list)) 642 elif shape_list and shape_list[0] is None: 643 # Set the batch size to 1 if undefined. 644 shape = tensor.shape.as_list() 645 shape[0] = 1 646 tensor.set_shape(shape) 647 648 if self._trackable_obj is None: 649 self._debug_info = _get_debug_info( 650 _build_debug_info_func(self._funcs[0].graph), graph_def) 651 else: 652 self._debug_info = _get_debug_info( 653 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 654 graph_def) 655 656 converter_kwargs = self._get_base_converter_args() 657 converter_kwargs.update(quant_mode.converter_flags()) 658 if not self.experimental_new_converter: 659 logging.warning( 660 "Please consider switching to the new converter by setting " 661 "experimental_new_converter=True. " 662 "The old converter (TOCO) is deprecated.") 663 else: 664 logging.info("Using new converter: If you encounter a problem " 665 "please file a bug. You can opt-out " 666 "by setting experimental_new_converter=False") 667 668 # Converts model. 669 result = _toco_convert_impl( 670 input_data=graph_def, 671 input_tensors=input_tensors, 672 output_tensors=output_tensors, 673 **converter_kwargs) 674 675 calibrate_and_quantize, flags = quant_mode.quantizer_flags() 676 if calibrate_and_quantize: 677 result = self._calibrate_quantize_model(result, **flags) 678 679 flags_modify_model_io_type = quant_mode.flags_modify_model_io_type( 680 self.inference_input_type, self.inference_output_type) 681 if flags_modify_model_io_type: 682 result = _modify_model_io_type(result, **flags_modify_model_io_type) 683 684 if self._sparsify_model(): 685 result = _mlir_sparsify(result) 686 687 return result 688 689 690class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): 691 """Converts the given SavedModel into TensorFlow Lite model. 692 693 Attributes: 694 saved_model_dir: Directory of the SavedModel. 695 """ 696 697 def __init__(self, 698 saved_model_dir, 699 saved_model_tags=None, 700 saved_model_exported_names=None, 701 trackable_obj=None): 702 """Constructor for TFLiteConverter. 703 704 Args: 705 saved_model_dir: Directory of the SavedModel. 706 saved_model_tags: Set of tags identifying the MetaGraphDef within the 707 SavedModel to analyze. All tags in the tag set must be present. (default 708 {tf.saved_model.SERVING}). 709 saved_model_exported_names: Names to be exported when the saved model 710 import path is on. 711 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 712 reference to this object needs to be maintained so that Variables do not 713 get garbage collected since functions have a weak reference to 714 Variables. This is only required when the tf.AutoTrackable object is not 715 maintained by the user (e.g. `from_saved_model`). 716 """ 717 super(TFLiteSavedModelConverterV2, self).__init__() 718 self.saved_model_dir = saved_model_dir 719 self._saved_model_tags = saved_model_tags 720 self._saved_model_exported_names = saved_model_exported_names 721 self._trackable_obj = trackable_obj 722 self._parse_saved_model_args(always_enable_saved_model_import=True) 723 self._enable_tflite_resource_variables = False 724 725 def convert(self): 726 """Converts a TensorFlow GraphDef based on instance variables. 727 728 Returns: 729 The converted data in serialized format. 730 731 Raises: 732 ValueError: 733 No concrete functions is specified. 734 Multiple concrete functions are specified. 735 Input shape is not specified. 736 Invalid quantization parameters. 737 """ 738 graph = _ops.Graph() 739 saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir) 740 saved_model.load_graph(graph, tags=self._saved_model_tags) 741 meta_graph = saved_model.get_meta_graph_def_from_tags( 742 self._saved_model_tags) 743 # If we can't use saved model importer, then fallback 744 # to frozen graph conversion path. 745 if self.saved_model_dir is None or not self.experimental_new_converter: 746 signature_def = meta_graph.signature_def[ 747 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 748 input_tensors = [ 749 graph.get_tensor_by_name(signature_def.inputs[key].name) 750 for key in signature_def.inputs 751 ] 752 output_tensors = [ 753 graph.get_tensor_by_name(signature_def.outputs[key].name) 754 for key in signature_def.outputs 755 ] 756 result = _freeze_saved_model( 757 self.saved_model_dir, None, None, None, self._saved_model_tags, 758 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 759 graph_def = result[0] 760 # We make sure to clear the saved_model_dir as there is some 761 # legacy code down in the caller that checks this. 762 # TODO(b/162537905): Clean these indirect dependencies. 763 self.saved_model_dir = None 764 return super(TFLiteSavedModelConverterV2, 765 self).convert(graph_def, input_tensors, output_tensors) 766 767 if self._trackable_obj is None: 768 self._debug_info = _get_debug_info( 769 _build_debug_info_func(self._funcs[0].graph), meta_graph.graph_def) 770 else: 771 self._debug_info = _get_debug_info( 772 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 773 meta_graph.graph_def) 774 775 # Get quantization options and do some sanity checks. 776 quant_mode = QuantizationMode(self.optimizations, self.target_spec, 777 self.representative_dataset, 778 meta_graph.graph_def) 779 self._validate_inference_input_output_types(quant_mode) 780 781 converter_kwargs = self._get_base_converter_args() 782 converter_kwargs.update(quant_mode.converter_flags()) 783 converter_kwargs.update({ 784 "enable_tflite_resource_variables": 785 self._enable_tflite_resource_variables 786 }) 787 788 result = _convert_saved_model(**converter_kwargs) 789 calibrate_and_quantize, flags = quant_mode.quantizer_flags() 790 if calibrate_and_quantize: 791 result = self._calibrate_quantize_model(result, **flags) 792 793 flags_modify_model_io_type = quant_mode.flags_modify_model_io_type( 794 self.inference_input_type, self.inference_output_type) 795 if flags_modify_model_io_type: 796 result = _modify_model_io_type(result, **flags_modify_model_io_type) 797 798 if self._sparsify_model(): 799 result = _mlir_sparsify(result) 800 801 return result 802 803 804class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): 805 """Converts the given Keras model into TensorFlow Lite model.""" 806 807 def __init__(self, keras_model, trackable_obj=None): 808 """Constructor for TFLiteConverter. 809 810 Args: 811 keras_model: tf.Keras.Model. 812 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 813 reference to this object needs to be maintained so that Variables do not 814 get garbage collected since functions have a weak reference to 815 Variables. This is only required when the tf.AutoTrackable object is not 816 maintained by the user (e.g. `from_saved_model`). 817 """ 818 super(TFLiteKerasModelConverterV2, self).__init__() 819 self._keras_model = keras_model 820 self._trackable_obj = trackable_obj 821 822 def _convert_as_saved_model(self): 823 """Converts a Keras model as a saved model. 824 825 Returns: 826 The converted data in serialized format. 827 """ 828 temp_dir = tempfile.mkdtemp() 829 try: 830 try: 831 self._keras_model.save(temp_dir, save_format="tf") 832 except Exception: # pylint: disable=broad-except 833 # When storing the given keras model to a saved model is failed, let's 834 # use original keras model conversion pipeline. 835 return None 836 self.saved_model_dir = temp_dir 837 self._saved_model_tags = set([_tag_constants.SERVING]) 838 self._saved_model_exported_names = [ 839 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 840 ] 841 self._parse_saved_model_args() 842 if self.saved_model_dir: 843 graph = _ops.Graph() 844 saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir) 845 saved_model.load_graph(graph, tags=self._saved_model_tags) 846 meta_graph = saved_model.get_meta_graph_def_from_tags( 847 self._saved_model_tags) 848 signature_def = meta_graph.signature_def[ 849 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 850 input_tensors = [ 851 graph.get_tensor_by_name(signature_def.inputs[key].name) 852 for key in signature_def.inputs 853 ] 854 output_tensors = [ 855 graph.get_tensor_by_name(signature_def.outputs[key].name) 856 for key in signature_def.outputs 857 ] 858 self._trackable_obj = _load(self.saved_model_dir, 859 self._saved_model_tags) 860 return super(TFLiteKerasModelConverterV2, 861 self).convert(meta_graph.graph_def, input_tensors, 862 output_tensors) 863 finally: 864 shutil.rmtree(temp_dir, True) 865 866 def convert(self): 867 """Converts a keras model based on instance variables. 868 869 Returns: 870 The converted data in serialized format. 871 872 Raises: 873 ValueError: 874 Multiple concrete functions are specified. 875 Input shape is not specified. 876 Invalid quantization parameters. 877 """ 878 saved_model_convert_result = self._convert_as_saved_model() 879 if saved_model_convert_result: 880 return saved_model_convert_result 881 882 input_signature = None 883 # If the model's call is not a `tf.function`, then we need to first get its 884 # input signature from `model_input_signature` method. We can't directly 885 # call `trace_model_call` because otherwise the batch dimension is set 886 # to None. 887 # Once we have better support for dynamic shapes, we can remove this. 888 if not isinstance(self._keras_model.call, _def_function.Function): 889 # Pass `keep_original_batch_size=True` will ensure that we get an input 890 # signature including the batch dimension specified by the user. 891 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 892 input_signature = _model_input_signature( 893 self._keras_model, keep_original_batch_size=True) 894 895 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 896 func = _trace_model_call(self._keras_model, input_signature) 897 concrete_func = func.get_concrete_function() 898 self._funcs = [concrete_func] 899 900 frozen_func, graph_def = ( 901 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 902 self._funcs[0], lower_control_flow=False)) 903 904 input_tensors = [ 905 tensor for tensor in frozen_func.inputs 906 if tensor.dtype != _dtypes.resource 907 ] 908 output_tensors = frozen_func.outputs 909 910 # Run a Grappler pass. 911 grappler_config = self._grappler_config() 912 # Skip running grappler when there are no optimizers to run. If not, 913 # grappler will run with the default optimizer set and it will lead to 914 # causing an unexpected behavior. 915 if grappler_config.graph_options.rewrite_options.optimizers: 916 graph_def = _run_graph_optimizations( 917 graph_def, 918 input_tensors, 919 output_tensors, 920 config=grappler_config, 921 graph=frozen_func.graph) 922 923 return super(TFLiteKerasModelConverterV2, 924 self).convert(graph_def, input_tensors, output_tensors) 925 926 927class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2): 928 """Converts the given frozen graph into TensorFlow Lite model.""" 929 930 def __init__(self, funcs, trackable_obj=None): 931 """Constructor for TFLiteConverter. 932 933 Args: 934 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 935 duplicate elements. 936 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 937 reference to this object needs to be maintained so that Variables do not 938 get garbage collected since functions have a weak reference to 939 Variables. This is only required when the tf.AutoTrackable object is not 940 maintained by the user (e.g. `from_saved_model`). 941 """ 942 super(TFLiteFrozenGraphConverterV2, self).__init__() 943 self._funcs = funcs 944 self._trackable_obj = trackable_obj 945 946 def convert(self): 947 """Converts a TensorFlow GraphDef based on instance variables. 948 949 Returns: 950 The converted data in serialized format. 951 952 Raises: 953 ValueError: 954 No concrete functions is specified. 955 Multiple concrete functions are specified. 956 Input shape is not specified. 957 Invalid quantization parameters. 958 """ 959 # TODO(b/130297984): Add support for converting multiple function. 960 961 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 962 raise ValueError("No ConcreteFunction is specified.") 963 964 if len(self._funcs) > 1: 965 raise ValueError("This converter can only convert a single " 966 "ConcreteFunction. Converting multiple functions is " 967 "under development.") 968 969 frozen_func, graph_def = ( 970 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 971 self._funcs[0], lower_control_flow=False)) 972 973 input_tensors = [ 974 tensor for tensor in frozen_func.inputs 975 if tensor.dtype != _dtypes.resource 976 ] 977 output_tensors = frozen_func.outputs 978 979 # Run a Grappler pass. 980 grappler_config = self._grappler_config() 981 # Skip running grappler when there are no optimizers to run. If not, 982 # grappler will run with the default optimizer set and it will lead to 983 # causing an unexpected behavior. 984 if grappler_config.graph_options.rewrite_options.optimizers: 985 graph_def = _run_graph_optimizations( 986 graph_def, 987 input_tensors, 988 output_tensors, 989 config=grappler_config, 990 graph=frozen_func.graph) 991 992 return super(TFLiteFrozenGraphConverterV2, 993 self).convert(graph_def, input_tensors, output_tensors) 994 995 996@_tf_export("lite.TFLiteConverter", v1=[]) 997class TFLiteConverterV2(TFLiteFrozenGraphConverterV2): 998 """Converts a TensorFlow model into TensorFlow Lite model. 999 1000 Attributes: 1001 optimizations: Experimental flag, subject to change. Set of optimizations 1002 to apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 1003 set of values of type `tf.lite.Optimize`) 1004 representative_dataset: A generator function used for integer quantization 1005 where each generated sample has the same order, type and shape as the 1006 inputs to the model. Usually, this is a small subset of a few hundred 1007 samples randomly chosen, in no particular order, from the training or 1008 evaluation dataset. This is an optional attribute, but required for full 1009 integer quantization, i.e, if `tf.int8` is the only supported type in 1010 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 1011 (default None) 1012 target_spec: Experimental flag, subject to change. Specifications of target 1013 device, including supported ops set, supported types and a set of user's 1014 defined TensorFlow operators required in the TensorFlow Lite runtime. 1015 Refer to `tf.lite.TargetSpec`. 1016 inference_input_type: Data type of the input layer. Note that integer types 1017 (tf.int8 and tf.uint8) are currently only supported for post training 1018 integer quantization and quantization aware training. (default tf.float32, 1019 must be in {tf.float32, tf.int8, tf.uint8}) 1020 inference_output_type: Data type of the output layer. Note that integer 1021 types (tf.int8 and tf.uint8) are currently only supported for post 1022 training integer quantization and quantization aware training. (default 1023 tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 1024 allow_custom_ops: Boolean indicating whether to allow custom operations. 1025 When False, any unknown operation is an error. When True, custom ops are 1026 created for any op that is unknown. The developer needs to provide these 1027 to the TensorFlow Lite runtime with a custom resolver. (default False) 1028 experimental_new_converter: Experimental flag, subject to change. Enables 1029 MLIR-based conversion instead of TOCO conversion. (default True) 1030 experimental_new_quantizer: Experimental flag, subject to change. Enables 1031 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 1032 (default False) 1033 1034 Example usage: 1035 1036 ```python 1037 # Converting a SavedModel to a TensorFlow Lite model. 1038 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) 1039 tflite_model = converter.convert() 1040 1041 # Converting a tf.Keras model to a TensorFlow Lite model. 1042 converter = tf.lite.TFLiteConverter.from_keras_model(model) 1043 tflite_model = converter.convert() 1044 1045 # Converting ConcreteFunctions to a TensorFlow Lite model. 1046 converter = tf.lite.TFLiteConverter.from_concrete_functions([func]) 1047 tflite_model = converter.convert() 1048 ``` 1049 """ 1050 1051 # pylint: disable=useless-super-delegation 1052 def __init__(self, funcs, trackable_obj=None): 1053 """Constructor for TFLiteConverter. 1054 1055 Args: 1056 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1057 duplicate elements. 1058 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 1059 reference to this object needs to be maintained so that Variables do not 1060 get garbage collected since functions have a weak reference to 1061 Variables. This is only required when the tf.AutoTrackable object is not 1062 maintained by the user (e.g. `from_saved_model`). 1063 """ 1064 super(TFLiteConverterV2, self).__init__(funcs, trackable_obj) 1065 1066 @classmethod 1067 def from_concrete_functions(cls, funcs): 1068 """Creates a TFLiteConverter object from ConcreteFunctions. 1069 1070 Args: 1071 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 1072 duplicate elements. Currently converter can only convert a single 1073 ConcreteFunction. Converting multiple functions is under development. 1074 1075 Returns: 1076 TFLiteConverter object. 1077 1078 Raises: 1079 Invalid input type. 1080 """ 1081 for func in funcs: 1082 if not isinstance(func, _function.ConcreteFunction): 1083 message = "This function takes in a list of ConcreteFunction." 1084 if isinstance(func, _def_function.Function): 1085 message += (" To get the ConcreteFunction from a Function," 1086 " call get_concrete_function.") 1087 raise ValueError(message) 1088 return cls(funcs) 1089 1090 @classmethod 1091 def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None): 1092 """Creates a TFLiteConverter object from a SavedModel directory. 1093 1094 Args: 1095 saved_model_dir: SavedModel directory to convert. 1096 signature_keys: List of keys identifying SignatureDef containing inputs 1097 and outputs. Elements should not be duplicated. By default the 1098 `signatures` attribute of the MetaGraphdef is used. (default 1099 saved_model.signatures) 1100 tags: Set of tags identifying the MetaGraphDef within the SavedModel to 1101 analyze. All tags in the tag set must be present. (default 1102 {tf.saved_model.SERVING} or {'serve'}) 1103 1104 Returns: 1105 TFLiteConverter object. 1106 1107 Raises: 1108 Invalid signature keys. 1109 """ 1110 # When run without eager enabled, this will return the legacy 1111 # TFLiteConverter. 1112 if not context.executing_eagerly(): 1113 signature_key = None 1114 if signature_keys: 1115 if len(signature_keys) != 1: 1116 raise ValueError("Only support a single signature key.") 1117 else: 1118 signature_key = signature_keys[0] 1119 logging.warning("Invoking the TF1 implementation of TFLiteConverter " 1120 "because eager is disabled. Consider enabling eager.") 1121 return TFLiteConverter.from_saved_model( 1122 saved_model_dir, signature_key=signature_key, tag_set=tags) 1123 1124 # Ensures any graphs created in Eager mode are able to run. This is required 1125 # in order to create a tf.estimator.Exporter that exports a TFLite model. 1126 if tags is None: 1127 tags = set([_tag_constants.SERVING]) 1128 1129 with context.eager_mode(): 1130 saved_model = _load(saved_model_dir, tags) 1131 if not signature_keys: 1132 signature_keys = saved_model.signatures 1133 1134 if len(signature_keys) != 1: 1135 raise ValueError("Only support a single signature key.") 1136 1137 funcs = [] 1138 for key in signature_keys: 1139 if key not in saved_model.signatures: 1140 raise ValueError("Invalid signature key '{}' found. Valid keys are " 1141 "'{}'.".format(key, ",".join(saved_model.signatures))) 1142 funcs.append(saved_model.signatures[key]) 1143 1144 saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags, 1145 signature_keys, 1146 saved_model) 1147 if saved_model_converter.saved_model_dir: 1148 return saved_model_converter 1149 1150 return cls(funcs, saved_model) 1151 1152 @classmethod 1153 def from_keras_model(cls, model): 1154 """Creates a TFLiteConverter object from a Keras model. 1155 1156 Args: 1157 model: tf.Keras.Model 1158 1159 Returns: 1160 TFLiteConverter object. 1161 """ 1162 return TFLiteKerasModelConverterV2(model) 1163 1164 # pylint: disable=useless-super-delegation 1165 def convert(self): 1166 """Converts a TensorFlow GraphDef based on instance variables. 1167 1168 Returns: 1169 The converted data in serialized format. 1170 1171 Raises: 1172 ValueError: 1173 No concrete functions is specified. 1174 Multiple concrete functions are specified. 1175 Input shape is not specified. 1176 Invalid quantization parameters. 1177 """ 1178 return super(TFLiteConverterV2, self).convert() 1179 1180 1181class TFLiteConverterBaseV1(TFLiteConverterBase): 1182 """Converter subclass to share functionality between V1 converters.""" 1183 1184 def __init__(self, experimental_debug_info_func): 1185 """Constructor for TFLiteConverter. 1186 1187 Args: 1188 experimental_debug_info_func: An experimental function to retrieve the 1189 graph debug info for a set of nodes from the `graph_def`. 1190 """ 1191 super(TFLiteConverterBaseV1, self).__init__() 1192 self.inference_type = _dtypes.float32 1193 self.inference_input_type = None 1194 self.inference_output_type = None 1195 self.output_format = constants.TFLITE 1196 self.quantized_input_stats = {} 1197 self.default_ranges_stats = None 1198 self.drop_control_dependency = True 1199 self.reorder_across_fake_quant = False 1200 self.change_concat_input_ranges = False 1201 self.dump_graphviz_dir = None 1202 self.dump_graphviz_video = False 1203 self.conversion_summary_dir = None 1204 self._debug_info_func = experimental_debug_info_func 1205 self._custom_opdefs = None 1206 1207 def __setattr__(self, name, value): 1208 if name == "post_training_quantize": 1209 warnings.warn("Property %s is deprecated, " 1210 "please use optimizations=[Optimize.DEFAULT]" 1211 " instead." % name) 1212 if value: 1213 self.optimizations = [Optimize.DEFAULT] 1214 else: 1215 self.optimizations = [] 1216 return 1217 if name == "target_ops": 1218 warnings.warn("Property %s is deprecated, please use " 1219 "target_spec.supported_ops instead." % name) 1220 self.target_spec.supported_ops = value 1221 return 1222 object.__setattr__(self, name, value) 1223 1224 def __getattribute__(self, name): 1225 if name == "post_training_quantize": 1226 warnings.warn("Property %s is deprecated, " 1227 "please use optimizations=[Optimize.DEFAULT]" 1228 " instead." % name) 1229 return Optimize.DEFAULT in set(self.optimizations) 1230 if name == "target_ops": 1231 warnings.warn("Property %s is deprecated, please use " 1232 "target_spec.supported_ops instead." % name) 1233 return self.target_spec.supported_ops 1234 return object.__getattribute__(self, name) 1235 1236 def _validate_quantized_input_stats(self, converter_kwargs, calibrate): 1237 """Ensure the `quantized_input_stats` flag is provided if required.""" 1238 1239 quantized_types = frozenset({_dtypes.int8, _dtypes.uint8}) 1240 1241 requires_quantized_input_stats = ( 1242 (converter_kwargs["inference_type"] in quantized_types or 1243 converter_kwargs["inference_input_type"] in quantized_types) and 1244 not calibrate) 1245 1246 if (requires_quantized_input_stats and 1247 not converter_kwargs["quantized_input_stats"]): 1248 raise ValueError( 1249 "The `quantized_input_stats` flag must be defined when either " 1250 "`inference_type` flag or `inference_input_type` flag is set to " 1251 "tf.int8 or tf.uint8. Currently, `inference_type={}` and " 1252 "`inference_input_type={}`.".format( 1253 _get_tf_type_name(converter_kwargs["inference_type"]), 1254 _get_tf_type_name(converter_kwargs["inference_input_type"]))) 1255 1256 def convert(self): 1257 """Converts a TensorFlow GraphDef based on instance variables. 1258 1259 Returns: 1260 The converted data in serialized format. Either a TFLite Flatbuffer or a 1261 Graphviz graph depending on value in `output_format`. 1262 1263 Raises: 1264 ValueError: 1265 Input shape is not specified. 1266 None value for dimension in input_tensor. 1267 """ 1268 quant_mode = QuantizationMode(self.optimizations, self.target_spec, 1269 self.representative_dataset, self._graph_def) 1270 1271 if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()): 1272 # Checks dimensions in input tensor. 1273 for tensor in self._input_tensors: 1274 shape = tensor.shape 1275 if not shape: 1276 raise ValueError("Provide an input shape for input array " 1277 "'{0}'.".format(_get_tensor_name(tensor))) 1278 # Note that shape_list might be empty for scalar shapes. 1279 shape_list = shape.as_list() 1280 if None in shape_list[1:]: 1281 raise ValueError( 1282 "None is only supported in the 1st dimension. Tensor '{0}' has " 1283 "invalid shape '{1}'.".format( 1284 _get_tensor_name(tensor), shape_list)) 1285 elif shape_list and shape_list[0] is None: 1286 self._set_batch_size(batch_size=1) 1287 1288 # Get quantization stats. Ensures there is one stat per name if the stats 1289 # are specified. 1290 if self.quantized_input_stats: 1291 quantized_stats = [] 1292 invalid_stats = [] 1293 for name in self.get_input_arrays(): 1294 if name in self.quantized_input_stats: 1295 quantized_stats.append(self.quantized_input_stats[name]) 1296 else: 1297 invalid_stats.append(name) 1298 1299 if invalid_stats: 1300 raise ValueError("Quantization input stats are not available for input " 1301 "tensors '{0}'.".format(",".join(invalid_stats))) 1302 else: 1303 quantized_stats = None 1304 1305 optimized_graph = self._graph_def 1306 if not self.saved_model_dir: 1307 # Disable grappler constant folding if there are training quant ops. 1308 if not quant_mode.contains_training_quant_op(): 1309 try: 1310 # TODO(b/150163103): Merge `disabling lower using switch merge' calls. 1311 # Grappler will also try to lower while loop into switch merge 1312 # representation which is undesired for Ophints, so we simply remove 1313 # those attributes to prevent Grappler from doing so. 1314 graph_def = _convert_to_constants.disable_lower_using_switch_merge( 1315 optimized_graph) 1316 # Run function inlining optimization to ensure any models generated 1317 # through the from_frozen_graph path have been inlined. 1318 optimized_graph = _run_graph_optimizations( 1319 graph_def, 1320 self._input_tensors, 1321 self._output_tensors, 1322 config=self._grappler_config(["function"])) 1323 except Exception: # pylint: disable=broad-except 1324 optimized_graph = self._graph_def 1325 1326 self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph) 1327 1328 converter_kwargs = self._get_base_converter_args() 1329 converter_kwargs.update( 1330 quant_mode.converter_flags(self.inference_type, 1331 self.inference_input_type)) 1332 converter_kwargs.update({ 1333 "output_format": self.output_format, 1334 "quantized_input_stats": quantized_stats, 1335 "default_ranges_stats": self.default_ranges_stats, 1336 "drop_control_dependency": self.drop_control_dependency, 1337 "reorder_across_fake_quant": self.reorder_across_fake_quant, 1338 "change_concat_input_ranges": self.change_concat_input_ranges, 1339 "dump_graphviz_dir": self.dump_graphviz_dir, 1340 "dump_graphviz_video": self.dump_graphviz_video, 1341 "conversion_summary_dir": self.conversion_summary_dir, 1342 "custom_opdefs": self._custom_opdefs, 1343 }) 1344 1345 if not self.experimental_new_converter: 1346 logging.warning( 1347 "Please consider switching to the new converter by setting " 1348 "experimental_new_converter=True. " 1349 "The old converter (TOCO) is deprecated.") 1350 else: 1351 logging.info("Using experimental converter: If you encountered a problem " 1352 "please file a bug. You can opt-out " 1353 "by setting experimental_new_converter=False") 1354 1355 if not self.experimental_new_converter: 1356 calibrate_quantize, flags = quant_mode.quantizer_flags( 1357 self.inference_input_type, self.inference_output_type) 1358 else: 1359 calibrate_quantize, flags = quant_mode.quantizer_flags() 1360 1361 self._validate_quantized_input_stats(converter_kwargs, calibrate_quantize) 1362 self._validate_experimental_new_quantizer_flag() 1363 1364 # Converts model. 1365 if self._has_valid_tensors(): 1366 result = _toco_convert_impl( 1367 input_data=optimized_graph, 1368 input_tensors=self._input_tensors, 1369 output_tensors=self._output_tensors, 1370 **converter_kwargs) 1371 else: 1372 result = _toco_convert_graph_def( 1373 input_data=optimized_graph, 1374 input_arrays_with_shape=self._input_arrays_with_shape, 1375 output_arrays=self._output_arrays, 1376 **converter_kwargs) 1377 1378 if calibrate_quantize: 1379 result = self._calibrate_quantize_model(result, **flags) 1380 1381 if self.experimental_new_converter or self.experimental_new_quantizer: 1382 flags_modify_model_io_type = quant_mode.flags_modify_model_io_type( 1383 self.inference_input_type, self.inference_output_type) 1384 if flags_modify_model_io_type: 1385 result = _modify_model_io_type(result, **flags_modify_model_io_type) 1386 1387 if self._sparsify_model(): 1388 result = _mlir_sparsify(result) 1389 1390 return result 1391 1392 def get_input_arrays(self): 1393 """Returns a list of the names of the input tensors. 1394 1395 Returns: 1396 List of strings. 1397 """ 1398 if self._has_valid_tensors(): 1399 return [_get_tensor_name(tensor) for tensor in self._input_tensors] 1400 else: 1401 return [name for name, _ in self._input_arrays_with_shape] 1402 1403 def _has_valid_tensors(self): 1404 """Checks if the input and output tensors have been initialized. 1405 1406 Returns: 1407 Bool. 1408 """ 1409 return self._input_tensors is not None and self._output_tensors 1410 1411 def _set_batch_size(self, batch_size): 1412 """Sets the first dimension of the input tensor to `batch_size`. 1413 1414 Args: 1415 batch_size: Batch size for the model. Replaces the first dimension of an 1416 input size array if undefined. (default 1) 1417 1418 Raises: 1419 ValueError: input_tensor is not defined. 1420 """ 1421 if not self._has_valid_tensors(): 1422 raise ValueError("The batch size cannot be set for this model. Please " 1423 "use input_shapes parameter.") 1424 1425 for tensor in self._input_tensors: 1426 shape = tensor.shape.as_list() 1427 if shape[0] is None: 1428 shape[0] = batch_size 1429 tensor.set_shape(shape) 1430 1431 def _is_unknown_shapes_allowed(self): 1432 # Ophint Converted nodes will need the shapes to be known. 1433 if _is_ophint_converted(self._graph_def): 1434 return False 1435 1436 if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed(): 1437 return False 1438 1439 # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by 1440 # the MLIR converter. 1441 if self.conversion_summary_dir: 1442 logging.warning( 1443 "`conversion_summary_dir` does not work with unknown shapes. " 1444 "Graphs with unknown shapes might be different than when this flag " 1445 "is disabled.") 1446 return False 1447 return True 1448 1449 1450class TFLiteSavedModelConverter(TFLiteConverterBaseV1): 1451 """Converts the given SavedModel into TensorFlow Lite model. 1452 1453 Attributes: 1454 saved_model_dir: Directory of the SavedModel. 1455 """ 1456 1457 def __init__(self, 1458 saved_model_dir, 1459 saved_model_tags, 1460 saved_model_exported_names, 1461 experimental_debug_info_func=None): 1462 """Constructor for TFLiteConverter. 1463 1464 Args: 1465 saved_model_dir: Directory of the SavedModel. 1466 saved_model_tags: Set of tags identifying the MetaGraphDef within the 1467 SavedModel to analyze. All tags in the tag set must be present. (default 1468 {tf.saved_model.SERVING}). 1469 saved_model_exported_names: Names to be exported when the saved model 1470 import path is on. 1471 experimental_debug_info_func: An experimental function to retrieve the 1472 graph debug info for a set of nodes from the `graph_def`. 1473 1474 Raises: 1475 ValueError: Invalid arguments. 1476 """ 1477 super(TFLiteSavedModelConverter, 1478 self).__init__(experimental_debug_info_func) 1479 self.saved_model_dir = saved_model_dir 1480 self._saved_model_tags = saved_model_tags 1481 self._saved_model_exported_names = saved_model_exported_names 1482 1483 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 1484 1485 if len(self._saved_model_exported_names) != 1: 1486 raise ValueError("Only support a single signature key.") 1487 1488 signature_key = self._saved_model_exported_names[0] 1489 1490 result = _freeze_saved_model(self.saved_model_dir, None, None, None, 1491 self._saved_model_tags, signature_key) 1492 self._graph_def = result[0] 1493 self._input_tensors = result[1] 1494 self._output_tensors = result[2] 1495 self._parse_saved_model_args() 1496 1497 1498class TFLiteKerasModelConverter(TFLiteConverterBaseV1): 1499 """Converts the given SavedModel into TensorFlow Lite model.""" 1500 1501 def __init__(self, 1502 model_file, 1503 input_arrays=None, 1504 input_shapes=None, 1505 output_arrays=None, 1506 custom_objects=None): 1507 """Constructor for TFLiteConverter. 1508 1509 Args: 1510 model_file: Full filepath of HDF5 file containing the tf.keras model. 1511 input_arrays: List of input tensors to freeze graph with. Uses input 1512 arrays from SignatureDef when none are provided. (default None) 1513 input_shapes: Dict of strings representing input tensor names to list of 1514 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 1515 Automatically determined when input shapes is None (e.g., {"foo" : 1516 None}). (default None) 1517 output_arrays: List of output tensors to freeze graph with. Uses output 1518 arrays from SignatureDef when none are provided. (default None) 1519 custom_objects: Dict mapping names (strings) to custom classes or 1520 functions to be considered during model deserialization. (default None) 1521 1522 Raises: 1523 ValueError: Invalid arguments. 1524 """ 1525 super(TFLiteKerasModelConverter, 1526 self).__init__(experimental_debug_info_func=None) 1527 # Handles Keras when Eager mode is enabled. 1528 if context.executing_eagerly(): 1529 if input_arrays or output_arrays: 1530 raise ValueError("`input_arrays` and `output_arrays` are unsupported " 1531 "with Eager mode. If your model requires any of these " 1532 "parameters, please use disable_eager_execution().") 1533 1534 keras_model = keras_deps.get_load_model_function()( 1535 model_file, custom_objects) 1536 function = _trace_model_call(keras_model) 1537 concrete_func = function.get_concrete_function() 1538 1539 frozen_func = _convert_to_constants.convert_variables_to_constants_v2( 1540 concrete_func, lower_control_flow=False) 1541 _set_tensor_shapes(frozen_func.inputs, input_shapes) 1542 self._keras_model = keras_model 1543 self._graph_def = frozen_func.graph.as_graph_def() 1544 self._input_tensors = frozen_func.inputs 1545 self._output_tensors = frozen_func.outputs 1546 self._debug_info_func = _build_debug_info_func(frozen_func.graph) 1547 return 1548 1549 # Handles Keras when Eager mode is disabled. 1550 keras_deps.get_clear_session_function()() 1551 keras_model = keras_deps.get_load_model_function()( 1552 model_file, custom_objects) 1553 sess = keras_deps.get_get_session_function()() 1554 1555 # Get input and output tensors. 1556 if input_arrays: 1557 input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) 1558 else: 1559 input_tensors = keras_model.inputs 1560 1561 if output_arrays: 1562 output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) 1563 else: 1564 output_tensors = keras_model.outputs 1565 _set_tensor_shapes(input_tensors, input_shapes) 1566 1567 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 1568 self._keras_model = keras_model 1569 self._graph_def = graph_def 1570 self._input_tensors = input_tensors 1571 self._output_tensors = output_tensors 1572 self._debug_info_func = _build_debug_info_func(sess.graph) 1573 1574 def _convert_as_saved_model(self): 1575 """Converts a Keras model as a saved model. 1576 1577 Returns: 1578 The converted data in serialized format. 1579 """ 1580 temp_dir = tempfile.mkdtemp() 1581 try: 1582 try: 1583 self._keras_model.save(temp_dir, save_format="tf") 1584 except Exception: # pylint: disable=broad-except 1585 # When storing the given keras model to a saved model is failed, let's 1586 # use original keras model conversion pipeline. 1587 return None 1588 tag_set = set([_tag_constants.SERVING]) 1589 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 1590 result = _freeze_saved_model(temp_dir, None, None, None, tag_set, 1591 signature_key) 1592 1593 self.saved_model_dir = temp_dir 1594 self._saved_model_tags = tag_set 1595 self._saved_model_exported_names = [signature_key] 1596 self._parse_saved_model_args() 1597 if self.saved_model_dir: 1598 self._graph_def = result[0] 1599 self._input_tensors = result[1] 1600 self._output_tensors = result[2] 1601 self._debug_info_func = _build_debug_info_func(result[3]) 1602 return super(TFLiteKerasModelConverter, self).convert() 1603 finally: 1604 shutil.rmtree(temp_dir, True) 1605 1606 def convert(self): 1607 """Converts a Keras model based on instance variables. 1608 1609 Returns: 1610 The converted data in serialized format. Either a TFLite Flatbuffer or a 1611 Graphviz graph depending on value in `output_format`. 1612 1613 Raises: 1614 ValueError: 1615 Input shape is not specified. 1616 None value for dimension in input_tensor. 1617 """ 1618 saved_model_convert_result = self._convert_as_saved_model() 1619 if saved_model_convert_result: 1620 return saved_model_convert_result 1621 1622 return super(TFLiteKerasModelConverter, self).convert() 1623 1624 1625class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1): 1626 """Converts the given frozen graph def into TensorFlow Lite model.""" 1627 1628 def __init__(self, 1629 graph_def, 1630 input_tensors, 1631 output_tensors, 1632 input_arrays_with_shape=None, 1633 output_arrays=None, 1634 experimental_debug_info_func=None): 1635 """Constructor for TFLiteConverter. 1636 1637 Args: 1638 graph_def: Frozen TensorFlow GraphDef. 1639 input_tensors: List of input tensors. Type and shape are computed using 1640 `foo.shape` and `foo.dtype`. 1641 output_tensors: List of output tensors (only .name is used from this). 1642 input_arrays_with_shape: Tuple of strings representing input tensor names 1643 and list of integers representing input shapes 1644 (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded 1645 into TensorFlow and when `input_tensors` and `output_tensors` are 1646 None. (default None) 1647 output_arrays: List of output tensors to freeze graph with. Use only when 1648 graph cannot be loaded into TensorFlow and when `input_tensors` and 1649 `output_tensors` are None. (default None) 1650 experimental_debug_info_func: An experimental function to retrieve the 1651 graph debug info for a set of nodes from the `graph_def`. 1652 1653 Raises: 1654 ValueError: Invalid arguments. 1655 """ 1656 super(TFLiteFrozenGraphConverter, 1657 self).__init__(experimental_debug_info_func) 1658 self._graph_def = graph_def 1659 self._input_tensors = input_tensors 1660 self._output_tensors = output_tensors 1661 1662 # Attributes are used by models that cannot be loaded into TensorFlow. 1663 if not self._has_valid_tensors(): 1664 if not input_arrays_with_shape or not output_arrays: 1665 raise ValueError( 1666 "If input_tensors and output_tensors are None, both " 1667 "input_arrays_with_shape and output_arrays must be defined.") 1668 self._input_arrays_with_shape = input_arrays_with_shape 1669 self._output_arrays = output_arrays 1670 1671 if input_tensors is not None and input_arrays_with_shape is not None: 1672 logging.warning("input_arrays_with_shape will be ignored when both the " 1673 "given input_tensors and input_arrays_with_shape are not " 1674 "None.") 1675 1676 if output_tensors is not None and output_arrays is not None: 1677 logging.warning("output_arrays will be ignored when both the given " 1678 "output_tensors and output_arrays are not None.") 1679 1680 1681@_tf_export(v1=["lite.TFLiteConverter"]) 1682class TFLiteConverter(TFLiteFrozenGraphConverter): 1683 """Convert a TensorFlow model into `output_format`. 1684 1685 This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras 1686 model into either a TFLite FlatBuffer or graph visualization. 1687 1688 Attributes: 1689 optimizations: Experimental flag, subject to change. Set of optimizations to 1690 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 1691 set of values of type `tf.lite.Optimize`) 1692 representative_dataset: A generator function used for integer quantization 1693 where each generated sample has the same order, type and shape as the 1694 inputs to the model. Usually, this is a small subset of a few hundred 1695 samples randomly chosen, in no particular order, from the training or 1696 evaluation dataset. This is an optional attribute, but required for full 1697 integer quantization, i.e, if `tf.int8` is the only supported type in 1698 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 1699 (default None) 1700 target_spec: Experimental flag, subject to change. Specifications of target 1701 device, including supported ops set, supported types and a set of user's 1702 defined TensorFlow operators required in the TensorFlow Lite runtime. 1703 Refer to `tf.lite.TargetSpec`. 1704 inference_type: Data type of numeric arrays, excluding the input layer. 1705 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 1706 inference_input_type: Data type of the numeric arrays in the input layer. If 1707 `inference_input_type` is in {tf.int8, tf.uint8}, then 1708 `quantized_input_stats` must be provided. (default is the value assigned 1709 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8}) 1710 inference_output_type: Data type of the numeric arrays in the output layer. 1711 (default is the value assigned to `inference_type`, must be in 1712 {tf.float32, tf.int8, tf.uint8}) 1713 quantized_input_stats: Map of input tensor names to a tuple of floats 1714 representing the mean and standard deviation of the training data. 1715 (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8 1716 or tf.uint8. (default None) 1717 default_ranges_stats: Tuple of integers (min, max) representing range values 1718 for all numeric arrays without a specified range. Intended for 1719 experimenting with quantization via "dummy quantization". (default None) 1720 allow_custom_ops: Boolean indicating whether to allow custom operations. 1721 When False any unknown operation is an error. When True, custom ops are 1722 created for any op that is unknown. The developer will need to provide 1723 these to the TensorFlow Lite runtime with a custom resolver. (default 1724 False) 1725 drop_control_dependency: Boolean indicating whether to drop control 1726 dependencies silently. This is due to TFLite not supporting control 1727 dependencies. (default True) 1728 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 1729 nodes in unexpected locations. Used when the location of the FakeQuant 1730 nodes is preventing graph transformations necessary to convert the graph. 1731 Results in a graph that differs from the quantized training graph, 1732 potentially causing differing arithmetic behavior. (default False) 1733 change_concat_input_ranges: Boolean to change behavior of min/max ranges for 1734 inputs and outputs of the concat operator for quantized models. Changes 1735 the ranges of concat operator overlap when true. (default False) 1736 output_format: Output file format. (default 1737 tf.compat.v1.lite.constants.TFLITE, must be in 1738 {tf.compat.v1.lite.constants.TFLITE, 1739 tf.compat.v1.lite.constants.GRAPHVIZ_DOT}) 1740 dump_graphviz_dir: Full filepath of folder to dump the graphs at various 1741 stages of processing GraphViz .dot files. Preferred over 1742 `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep 1743 the requirements of the output file. (default None) 1744 dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot 1745 files after every graph transformation. Requires the `dump_graphviz_dir` 1746 flag to be specified. (default False) 1747 conversion_summary_dir: Full path of the directory to store conversion logs. 1748 (default None) 1749 target_ops: Deprecated. Please use `target_spec.supported_ops` instead. 1750 post_training_quantize: Deprecated. Please use `optimizations` instead and 1751 set it to `{tf.lite.Optimize.DEFAULT}`. (default False) 1752 experimental_new_converter: Experimental flag, subject to change. Enables 1753 MLIR-based conversion instead of TOCO conversion. (default True) 1754 experimental_new_quantizer: Experimental flag, subject to change. Enables 1755 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 1756 (default False) 1757 1758 Example usage: 1759 1760 ```python 1761 # Converting a GraphDef from session. 1762 converter = tf.compat.v1.lite.TFLiteConverter.from_session( 1763 sess, in_tensors, out_tensors) 1764 tflite_model = converter.convert() 1765 open("converted_model.tflite", "wb").write(tflite_model) 1766 1767 # Converting a GraphDef from file. 1768 converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( 1769 graph_def_file, input_arrays, output_arrays) 1770 tflite_model = converter.convert() 1771 open("converted_model.tflite", "wb").write(tflite_model) 1772 1773 # Converting a SavedModel. 1774 converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model( 1775 saved_model_dir) 1776 tflite_model = converter.convert() 1777 open("converted_model.tflite", "wb").write(tflite_model) 1778 1779 # Converting a tf.keras model. 1780 converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file( 1781 keras_model) 1782 tflite_model = converter.convert() 1783 open("converted_model.tflite", "wb").write(tflite_model) 1784 ``` 1785 """ 1786 1787 # pylint: disable=useless-super-delegation 1788 def __init__(self, 1789 graph_def, 1790 input_tensors, 1791 output_tensors, 1792 input_arrays_with_shape=None, 1793 output_arrays=None, 1794 experimental_debug_info_func=None): 1795 """Constructor for TFLiteConverter. 1796 1797 Args: 1798 graph_def: Frozen TensorFlow GraphDef. 1799 input_tensors: List of input tensors. Type and shape are computed using 1800 `foo.shape` and `foo.dtype`. 1801 output_tensors: List of output tensors (only .name is used from this). 1802 input_arrays_with_shape: Tuple of strings representing input tensor names 1803 and list of integers representing input shapes 1804 (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded 1805 into TensorFlow and when `input_tensors` and `output_tensors` are 1806 None. (default None) 1807 output_arrays: List of output tensors to freeze graph with. Use only when 1808 graph cannot be loaded into TensorFlow and when `input_tensors` and 1809 `output_tensors` are None. (default None) 1810 experimental_debug_info_func: An experimental function to retrieve the 1811 graph debug info for a set of nodes from the `graph_def`. 1812 1813 Raises: 1814 ValueError: Invalid arguments. 1815 """ 1816 super(TFLiteConverter, 1817 self).__init__(graph_def, input_tensors, output_tensors, 1818 input_arrays_with_shape, output_arrays, 1819 experimental_debug_info_func) 1820 1821 @classmethod 1822 def from_session(cls, sess, input_tensors, output_tensors): 1823 """Creates a TFLiteConverter class from a TensorFlow Session. 1824 1825 Args: 1826 sess: TensorFlow Session. 1827 input_tensors: List of input tensors. Type and shape are computed using 1828 `foo.shape` and `foo.dtype`. 1829 output_tensors: List of output tensors (only .name is used from this). 1830 1831 Returns: 1832 TFLiteConverter class. 1833 """ 1834 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 1835 return cls( 1836 graph_def, 1837 input_tensors, 1838 output_tensors, 1839 experimental_debug_info_func=_build_debug_info_func(sess.graph)) 1840 1841 @classmethod 1842 def from_frozen_graph(cls, 1843 graph_def_file, 1844 input_arrays, 1845 output_arrays, 1846 input_shapes=None): 1847 """Creates a TFLiteConverter class from a file containing a frozen GraphDef. 1848 1849 Args: 1850 graph_def_file: Full filepath of file containing frozen GraphDef. 1851 input_arrays: List of input tensors to freeze graph with. 1852 output_arrays: List of output tensors to freeze graph with. 1853 input_shapes: Dict of strings representing input tensor names to list of 1854 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 1855 Automatically determined when input shapes is None (e.g., {"foo" : 1856 None}). (default None) 1857 1858 Returns: 1859 TFLiteConverter class. 1860 1861 Raises: 1862 IOError: 1863 File not found. 1864 Unable to parse input file. 1865 ValueError: 1866 The graph is not frozen. 1867 input_arrays or output_arrays contains an invalid tensor name. 1868 input_shapes is not correctly defined when required 1869 """ 1870 with _ops.Graph().as_default(): 1871 with _session.Session() as sess: 1872 # Read GraphDef from file. 1873 if not _file_io.file_exists(graph_def_file): 1874 raise IOError("File '{0}' does not exist.".format(graph_def_file)) 1875 with _file_io.FileIO(graph_def_file, "rb") as f: 1876 file_content = f.read() 1877 1878 try: 1879 graph_def = _graph_pb2.GraphDef() 1880 graph_def.ParseFromString(file_content) 1881 except (_text_format.ParseError, DecodeError): 1882 try: 1883 print("Ignore 'tcmalloc: large alloc' warnings.") 1884 1885 if not isinstance(file_content, str): 1886 if PY2: 1887 file_content = six.ensure_binary(file_content, "utf-8") 1888 else: 1889 file_content = six.ensure_text(file_content, "utf-8") 1890 graph_def = _graph_pb2.GraphDef() 1891 _text_format.Merge(file_content, graph_def) 1892 except (_text_format.ParseError, DecodeError): 1893 raise IOError( 1894 "Unable to parse input file '{}'.".format(graph_def_file)) 1895 1896 # Handles models with custom TFLite ops that cannot be resolved in 1897 # TensorFlow. 1898 load_model_in_session = True 1899 try: 1900 _import_graph_def(graph_def, name="") 1901 except _NotFoundError: 1902 load_model_in_session = False 1903 1904 if load_model_in_session: 1905 # Check if graph is frozen. 1906 if not _is_frozen_graph(sess): 1907 raise ValueError("Please freeze the graph using freeze_graph.py.") 1908 1909 # Get input and output tensors. 1910 input_tensors = _get_tensors_from_tensor_names( 1911 sess.graph, input_arrays) 1912 output_tensors = _get_tensors_from_tensor_names( 1913 sess.graph, output_arrays) 1914 _set_tensor_shapes(input_tensors, input_shapes) 1915 1916 return cls(sess.graph_def, input_tensors, output_tensors) 1917 else: 1918 if not input_shapes: 1919 raise ValueError("input_shapes must be defined for this model.") 1920 if set(input_arrays) != set(input_shapes.keys()): 1921 raise ValueError("input_shapes must contain a value for each item " 1922 "in input_array.") 1923 1924 input_arrays_with_shape = [ 1925 (name, input_shapes[name]) for name in input_arrays 1926 ] 1927 return cls( 1928 graph_def, 1929 input_tensors=None, 1930 output_tensors=None, 1931 input_arrays_with_shape=input_arrays_with_shape, 1932 output_arrays=output_arrays) 1933 1934 @classmethod 1935 def from_saved_model(cls, 1936 saved_model_dir, 1937 input_arrays=None, 1938 input_shapes=None, 1939 output_arrays=None, 1940 tag_set=None, 1941 signature_key=None): 1942 """Creates a TFLiteConverter class from a SavedModel. 1943 1944 Args: 1945 saved_model_dir: SavedModel directory to convert. 1946 input_arrays: List of input tensors to freeze graph with. Uses input 1947 arrays from SignatureDef when none are provided. (default None) 1948 input_shapes: Dict of strings representing input tensor names to list of 1949 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 1950 Automatically determined when input shapes is None (e.g., {"foo" : 1951 None}). (default None) 1952 output_arrays: List of output tensors to freeze graph with. Uses output 1953 arrays from SignatureDef when none are provided. (default None) 1954 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 1955 analyze. All tags in the tag set must be present. (default 1956 {tf.saved_model.SERVING}) 1957 signature_key: Key identifying SignatureDef containing inputs and outputs. 1958 (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 1959 1960 Returns: 1961 TFLiteConverter class. 1962 """ 1963 if tag_set is None: 1964 tag_set = set([_tag_constants.SERVING]) 1965 if signature_key is None: 1966 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 1967 1968 saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set, 1969 [signature_key]) 1970 if saved_model_converter.saved_model_dir: 1971 return saved_model_converter 1972 1973 result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, 1974 output_arrays, tag_set, signature_key) 1975 1976 return cls( 1977 graph_def=result[0], 1978 input_tensors=result[1], 1979 output_tensors=result[2], 1980 experimental_debug_info_func=_build_debug_info_func(result[3])) 1981 1982 @classmethod 1983 def from_keras_model_file(cls, 1984 model_file, 1985 input_arrays=None, 1986 input_shapes=None, 1987 output_arrays=None, 1988 custom_objects=None): 1989 """Creates a TFLiteConverter class from a tf.keras model file. 1990 1991 Args: 1992 model_file: Full filepath of HDF5 file containing the tf.keras model. 1993 input_arrays: List of input tensors to freeze graph with. Uses input 1994 arrays from SignatureDef when none are provided. (default None) 1995 input_shapes: Dict of strings representing input tensor names to list of 1996 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 1997 Automatically determined when input shapes is None (e.g., {"foo" : 1998 None}). (default None) 1999 output_arrays: List of output tensors to freeze graph with. Uses output 2000 arrays from SignatureDef when none are provided. (default None) 2001 custom_objects: Dict mapping names (strings) to custom classes or 2002 functions to be considered during model deserialization. (default None) 2003 2004 Returns: 2005 TFLiteConverter class. 2006 """ 2007 return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes, 2008 output_arrays, custom_objects) 2009 2010 # pylint: disable=useless-super-delegation 2011 def convert(self): 2012 """Converts a TensorFlow GraphDef based on instance variables. 2013 2014 Returns: 2015 The converted data in serialized format. Either a TFLite Flatbuffer or a 2016 Graphviz graph depending on value in `output_format`. 2017 2018 Raises: 2019 ValueError: 2020 Input shape is not specified. 2021 None value for dimension in input_tensor. 2022 """ 2023 return super(TFLiteConverter, self).convert() 2024 2025 2026@_tf_export(v1=["lite.TocoConverter"]) 2027class TocoConverter(object): 2028 """Convert a TensorFlow model into `output_format` using TOCO. 2029 2030 This class has been deprecated. Please use `lite.TFLiteConverter` instead. 2031 """ 2032 2033 @classmethod 2034 @_deprecation.deprecated(None, 2035 "Use `lite.TFLiteConverter.from_session` instead.") 2036 def from_session(cls, sess, input_tensors, output_tensors): 2037 """Creates a TocoConverter class from a TensorFlow Session.""" 2038 return TFLiteConverter.from_session(sess, input_tensors, output_tensors) 2039 2040 @classmethod 2041 @_deprecation.deprecated( 2042 None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.") 2043 def from_frozen_graph(cls, 2044 graph_def_file, 2045 input_arrays, 2046 output_arrays, 2047 input_shapes=None): 2048 """Creates a TocoConverter class from a file containing a frozen graph.""" 2049 return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 2050 output_arrays, input_shapes) 2051 2052 @classmethod 2053 @_deprecation.deprecated( 2054 None, "Use `lite.TFLiteConverter.from_saved_model` instead.") 2055 def from_saved_model(cls, 2056 saved_model_dir, 2057 input_arrays=None, 2058 input_shapes=None, 2059 output_arrays=None, 2060 tag_set=None, 2061 signature_key=None): 2062 """Creates a TocoConverter class from a SavedModel.""" 2063 return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays, 2064 input_shapes, output_arrays, 2065 tag_set, signature_key) 2066 2067 @classmethod 2068 @_deprecation.deprecated( 2069 None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.") 2070 def from_keras_model_file(cls, 2071 model_file, 2072 input_arrays=None, 2073 input_shapes=None, 2074 output_arrays=None): 2075 """Creates a TocoConverter class from a tf.keras model file.""" 2076 return TFLiteConverter.from_keras_model_file(model_file, input_arrays, 2077 input_shapes, output_arrays) 2078