# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a frozen graph into a TFLite FlatBuffer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.spawn
import enum  # pylint: disable=g-bad-import-order
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile

import six
from six.moves import map

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export

_quantized_inference_types = [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]


# Input quantization stats are required when `inference_type` or
# `inference_input_type` is a quantized type and the model does not use
# post-training quantization.
def _requires_input_stats(toco_flags):
  return ((toco_flags.inference_type in _quantized_inference_types or
           toco_flags.inference_input_type in _quantized_inference_types) and
          not toco_flags.post_training_quantize)


# Find the toco_from_protos binary using the resource loader if running from
# bazel; otherwise we are in a pip install, where console_scripts already
# provides the toco_from_protos tool.
if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"


def _try_convert_to_unicode(output):
  if output is None:
    return u""

  if isinstance(output, bytes):
    try:
      return six.ensure_text(output)
    except UnicodeDecodeError:
      pass
  return output

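
# Illustrative sketch (not executed by this module): `_requires_input_stats`
# is true only for quantized inference types without post-training
# quantization, e.g.:
#
#   flags = _toco_flags_pb2.TocoFlags()
#   flags.inference_type = _types_pb2.QUANTIZED_UINT8
#   assert _requires_input_stats(flags)
#   flags.post_training_quantize = True
#   assert not _requires_input_stats(flags)
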

@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enum class defining the sets of ops available to generate TFLite models.

  WARNING: Experimental interface, subject to change.
  """
  # Convert model using TensorFlow Lite builtin ops.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Convert model using only TensorFlow Lite quantized int8 operations.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Convert model using only TensorFlow Lite operations with quantized int8
  # weights, int16 activations and int64 bias.
  # Specifying this will throw an error for operations that do not yet have
  # quantized implementations.
  # This quantization mode may be used in models for super-resolution,
  # audio signal processing or image de-noising. It improves accuracy
  # significantly, but only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized. They are only compatible with CPU execution, and have not
  # been optimized for production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = \
      "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"

  def __str__(self):
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns a list of OpsSet options as a list of strings."""
    return [str(option) for option in list(OpsSet)]


class ConverterError(Exception):
  """Raised when an error occurs during model conversion."""
  pass


def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.INT8,
                  enable_numeric_verify=False):
  """Quantize `input_data_str` with calibration results.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or
      per-tensor quantization.
    fully_quantize: Bool indicating whether to fully quantize the model.
      Besides the model body, the input/output will be quantized as well.
    inference_type: Data type for the activations. The default value is int8.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.

  Returns:
    Quantized model in serialized form (e.g. a TFLITE model) with
    floating-point inputs and outputs.
  """
  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                      disable_per_channel,
                                                      fully_quantize,
                                                      inference_type,
                                                      enable_numeric_verify)


def mlir_sparsify(input_data_str):
  """Sparsify `input_data_str` to encode sparse tensors with proper format.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLITE model).
  """
  return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)


def register_custom_opdefs(custom_opdefs_list):
  """Register the given custom opdefs to the TensorFlow global op registry.

  Args:
    custom_opdefs_list: String representing the custom ops OpDefs that are
      included in the GraphDef.

  Returns:
    True if the registration is successfully completed.
  """
  return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)

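
# Illustrative sketch (not executed by this module): a calibrated TFLite
# flatbuffer can be passed through the MLIR quantizer and sparsifier above.
# `calibrated_model` is an assumed bytes object produced elsewhere (e.g. by
# the calibration wrapper).
#
#   quantized_model = mlir_quantize(calibrated_model, fully_quantize=True)
#   sparse_model = mlir_sparsify(quantized_model)
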

def toco_convert_protos(model_flags_str,
                        toco_flags_str,
                        input_data_str,
                        debug_info_str=None,
                        enable_mlir_converter=False):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing, consider using the more friendly
  `tf.compat.v1.lite.toco_convert`.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of the
      default TOCO conversion. (default False)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When conversion fails in TFLiteConverter, usually due to
      ops not being supported.
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded.
  """
  # Historically, TOCO conversion failures would trigger a crash, so we would
  # attempt to run the converter out-of-process. The MLIR conversion pipeline
  # surfaces errors instead, and can be safely run in-process.
  if enable_mlir_converter or not _toco_from_proto_bin:
    try:
      model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                 toco_flags_str, input_data_str,
                                                 debug_info_str,
                                                 enable_mlir_converter)
      return model_str
    except Exception as e:
      raise ConverterError(str(e))

  if distutils.spawn.find_executable(_toco_from_proto_bin) is None:
    raise ConverterError("""Could not find toco_from_protos binary, make sure
your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.

For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin

Alternatively, use virtualenv.""")
  # On Windows, NamedTemporaryFile is of limited use because an open temporary
  # file cannot be read by a second handle. So we create the temporaries with
  # delete=False and close and delete them explicitly.
  toco_filename, model_filename, input_filename, debug_filename, \
      output_filename = (None, None, None, None, None)
  try:
    # Build all input files.
    with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      toco_filename = fp_toco.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_toco.write(toco_flags_str)
      fp_input.write(six.ensure_binary(input_data_str))
      debug_info_str = debug_info_str if debug_info_str else ""
      # `fp_debug.write` requires a bytes-like object, but `debug_info_str`
      # arrives as a str in some cases (e.g. several subtests of the
      # "convert_test" unit test), which would fail with
      # "TypeError: a bytes-like object is required, not 'str'".
      # Encode it to bytes where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file.
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run the converter binary.
    cmd = [
        _toco_from_proto_bin,
        model_filename,
        toco_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    if enable_mlir_converter:
      cmd.append("--enable_mlir_converter")
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually clean up the temporary files.
    for filename in [
        toco_filename, input_filename, model_filename, debug_filename,
        output_filename
    ]:
      try:
        _os.unlink(filename)
      except (OSError, TypeError):
        pass

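
# Illustrative sketch (not executed by this module): the calling convention
# for `toco_convert_protos`. `model_flags` and `toco_flags` are assumed to be
# populated protos (e.g. via `build_toco_convert_protos` below), and
# `graph_def` an already-loaded GraphDef.
#
#   tflite_model = toco_convert_protos(
#       model_flags.SerializeToString(),
#       toco_flags.SerializeToString(),
#       graph_def.SerializeToString(),
#       enable_mlir_converter=True)
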

def build_toco_flags(inference_type=dtypes.float32,
                     inference_input_type=None,
                     input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                     output_format=lite_constants.TFLITE,
                     default_ranges_stats=None,
                     drop_control_dependency=True,
                     reorder_across_fake_quant=False,
                     allow_custom_ops=False,
                     custom_opdefs=None,
                     post_training_quantize=False,
                     quantize_to_float16=False,
                     dump_graphviz_dir=None,
                     dump_graphviz_video=False,
                     target_ops=None,
                     conversion_summary_dir=None,
                     select_user_tf_ops=None,
                     enable_tflite_resource_variables=False,
                     **_):
  """Build the TOCO flags object from params."""
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.inference_type = util.convert_dtype_to_tflite_type(inference_type)
  if inference_input_type:
    toco.inference_input_type = util.convert_dtype_to_tflite_type(
        inference_input_type)
  else:
    toco.inference_input_type = toco.inference_type
  toco.drop_control_dependency = drop_control_dependency
  toco.reorder_across_fake_quant = reorder_across_fake_quant
  toco.allow_custom_ops = allow_custom_ops
  if custom_opdefs:
    toco.custom_opdefs.extend(custom_opdefs)
  if select_user_tf_ops:
    toco.select_user_tf_ops.extend(select_user_tf_ops)
  toco.post_training_quantize = post_training_quantize
  toco.quantize_to_float16 = quantize_to_float16
  if default_ranges_stats:
    toco.default_ranges_min = default_ranges_stats[0]
    toco.default_ranges_max = default_ranges_stats[1]
  if dump_graphviz_dir:
    toco.dump_graphviz_dir = dump_graphviz_dir
  toco.dump_graphviz_include_video = dump_graphviz_video
  if conversion_summary_dir:
    toco.conversion_summary_dir = conversion_summary_dir
  if target_ops:
    if OpsSet.SELECT_TF_OPS in set(target_ops):
      toco.enable_select_tf_ops = True
    if set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
      toco.force_select_tf_ops = True
  toco.enable_tflite_resource_variables = enable_tflite_resource_variables
  return toco

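
# Illustrative sketch (not executed by this module): building a TocoFlags
# proto for a float model that falls back to select TF ops when a builtin is
# unavailable.
#
#   toco_flags = build_toco_flags(
#       inference_type=dtypes.float32,
#       target_ops=set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]))
#   assert toco_flags.enable_select_tf_ops
#   assert not toco_flags.force_select_tf_ops
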

def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=dtypes.float32,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,
                              output_format=lite_constants.TFLITE,
                              quantized_input_stats=None,
                              default_ranges_stats=None,
                              drop_control_dependency=True,
                              reorder_across_fake_quant=False,
                              allow_custom_ops=False,
                              custom_opdefs=None,
                              change_concat_input_ranges=False,
                              post_training_quantize=False,
                              quantize_to_float16=False,
                              dump_graphviz_dir=None,
                              dump_graphviz_video=False,
                              target_ops=None,
                              allow_nonexistent_arrays=False,
                              debug_info=None,
                              conversion_summary_dir=None,
                              saved_model_dir=None,
                              saved_model_version=0,
                              saved_model_tags=None,
                              saved_model_exported_names=None,
                              select_user_tf_ops=None):
  """Builds protocol buffers describing a conversion of a model using TOCO.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer.
      If `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    input_format: Type of data to read.
      (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
    input_shapes: Input array shape. (default None, must be None or a list of
      the same length as `input_tensors`.)
    output_format: Output file format. (default TFLITE, must be in
      {TFLITE, GRAPHVIZ_DOT})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is
      tf.int8 or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers representing (min, max) range
      values for all arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization".
      (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the
      graph. Results in a graph that differs from the quantized training
      graph, potentially causing differing arithmetic behavior.
      (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false, any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver.
      (default False)
    custom_opdefs: List of strings representing custom ops OpDefs that are
      included in the GraphDef. Required when using custom operations with
      the MLIR-based converter. (default None)
    change_concat_input_ranges: Boolean to change the behavior of min/max
      ranges for inputs and outputs of the concat operator for quantized
      models. Changes the ranges of the concat operator to overlap when true.
      (default False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will
      be latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers
      to float16. (default False)
    dump_graphviz_dir: Full filepath of the folder to dump GraphViz .dot files
      of the graph at various stages of processing. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use.
      (default set([OpsSet.TFLITE_BUILTINS]))
    allow_nonexistent_arrays: Allow specifying array names that don't exist
      or are unused in the final graph. (default False)
    debug_info: `GraphDebugInfo` proto containing the stack traces for the
      original nodes referred to by the converted graph.
    conversion_summary_dir: A string, the path to the generated conversion
      logs.
    saved_model_dir: Filepath of the saved model to be converted. This value
      is non-empty only when the SavedModel import path is used; otherwise,
      the GraphDef-based conversion is performed.
    saved_model_version: SavedModel file format version of the saved model
      file to be converted. This value is set only when the SavedModel import
      path is used.
    saved_model_tags: Set of strings of SavedModel tags, formatted as
      comma-separated values. This value is set only when the SavedModel
      import path is used.
    saved_model_exported_names: Names to be exported (default: export all)
      when the SavedModel import path is on. This value is set only when the
      SavedModel import path is used.
    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
      supported in the TensorFlow Lite runtime. These ops will be supported
      as select TensorFlow ops.

  Returns:
    model_flags, toco_flags, debug_info: three protocol buffers describing the
      conversion process and debug information.

  Raises:
    ValueError:
      If the input tensor type is unknown.
      If mean_values or std_dev_values are missing.
    RuntimeError: If TOCO fails to convert (in which case the runtime error's
      error text will contain the TOCO error log).
  """
  toco = build_toco_flags(inference_type, inference_input_type, input_format,
                          output_format, default_ranges_stats,
                          drop_control_dependency, reorder_across_fake_quant,
                          allow_custom_ops, custom_opdefs,
                          post_training_quantize, quantize_to_float16,
                          dump_graphviz_dir, dump_graphviz_video, target_ops,
                          conversion_summary_dir, select_user_tf_ops)
  model = _model_flags_pb2.ModelFlags()
  model.change_concat_input_ranges = change_concat_input_ranges
  for idx, input_tensor in enumerate(input_tensors):
    input_array = model.input_arrays.add()
    if saved_model_dir:
      input_array.name = input_tensor.name
    else:
      input_array.name = util.get_tensor_name(input_tensor)
    input_array.data_type = util.convert_dtype_to_tflite_type(
        input_tensor.dtype)

    if _requires_input_stats(toco) and quantized_input_stats:
      input_array.mean_value, input_array.std_value = quantized_input_stats[
          idx]

    if input_shapes is None:
      shape = input_tensor.shape
    else:
      shape = input_shapes[idx]

    if shape.rank is not None:
      # Create shapes with -1 for unknown dimensions.
      dims = []
      for dim in shape:
        if (dim is None or
            (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
          dims.append(-1)
        else:
          dims.append(int(dim))
      input_array.shape.dims.extend(dims)
      input_array.shape.unknown_rank = False
    else:
      input_array.shape.unknown_rank = True

  for output_tensor in output_tensors:
    if saved_model_dir:
      model.output_arrays.append(output_tensor.name)
    else:
      model.output_arrays.append(util.get_tensor_name(output_tensor))

  model.allow_nonexistent_arrays = allow_nonexistent_arrays

  if saved_model_dir:
    model.saved_model_dir = saved_model_dir
  model.saved_model_version = saved_model_version
  if saved_model_tags:
    model.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model.saved_model_exported_names.extend(saved_model_exported_names)

  return model, toco, debug_info

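
# Illustrative sketch (not executed by this module): with a frozen
# tf.compat.v1 graph, the protos can be built directly from its tensors.
# `in_tensor` and `out_tensor` are assumed to come from an existing
# tf.compat.v1.Session graph.
#
#   model_flags, toco_flags, debug_info = build_toco_convert_protos(
#       input_tensors=[in_tensor],
#       output_tensors=[out_tensor],
#       inference_type=dtypes.float32)
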

def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
                           enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  This function is used to convert GraphDefs that cannot be loaded into
  TensorFlow to TFLite. Conversion can be customized by providing arguments
  that are forwarded to `build_toco_convert_protos` (see documentation for
  details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_arrays_with_shape: Tuples of strings representing input tensor names
      and lists of integers representing input shapes
      (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
      into TensorFlow and when `input_tensors` is None. (default None)
    output_arrays: List of output tensors to freeze graph with. Use only when
      graph cannot be loaded into TensorFlow and when `output_tensors` is
      None. (default None)
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, _ = build_toco_convert_protos(
      input_tensors=[], output_tensors=[], *args, **kwargs)

  for idx, (name, shape) in enumerate(input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if _requires_input_stats(toco_flags):
      if (("quantized_input_stats" not in kwargs) or
          (not kwargs["quantized_input_stats"])):
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
      input_array.mean_value, input_array.std_value = kwargs[
          "quantized_input_stats"][idx]
    input_array.name = name
    input_array.shape.dims.extend(list(map(int, shape)))

  for name in output_arrays:
    model_flags.output_arrays.append(name)

  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      enable_mlir_converter=enable_mlir_converter)
  return data

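
# Illustrative sketch (not executed by this module): converting a GraphDef
# that cannot be loaded into a TensorFlow session, so input names and shapes
# are supplied explicitly. `graph_def` is an assumed, already-populated
# GraphDef; the tensor names are hypothetical.
#
#   tflite_model = toco_convert_graph_def(
#       graph_def,
#       input_arrays_with_shape=[("input", [1, 16, 16, 3])],
#       output_arrays=["output"],
#       enable_mlir_converter=False)
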

def toco_convert_impl(input_data, input_tensors, output_tensors,
                      enable_mlir_converter, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    enable_mlir_converter: Enables MLIR-based conversion instead of TOCO
      conversion.
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  model_flags, toco_flags, debug_info = build_toco_convert_protos(
      input_tensors, output_tensors, *args, **kwargs)
  debug_info_str = debug_info.SerializeToString() if debug_info else None
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info_str,
      enable_mlir_converter=enable_mlir_converter)
  return data

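
# Illustrative sketch (not executed by this module): the tensor-based
# counterpart of `toco_convert_graph_def`, used when the graph can be loaded
# into a session; `sess`, `in_tensor` and `out_tensor` are assumed names.
#
#   tflite_model = toco_convert_impl(
#       sess.graph_def, [in_tensor], [out_tensor],
#       enable_mlir_converter=False)
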

def convert_saved_model(saved_model_dir=None,
                        saved_model_version=0,
                        saved_model_tags=None,
                        saved_model_exported_names=None,
                        **kwargs):
  """Converts a SavedModel using the TF Lite converter."""
  model_flags = _model_flags_pb2.ModelFlags()
  if saved_model_dir:
    model_flags.saved_model_dir = saved_model_dir
  model_flags.saved_model_version = saved_model_version
  if saved_model_tags:
    model_flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    model_flags.saved_model_exported_names.extend(saved_model_exported_names)
  toco_flags = build_toco_flags(**kwargs)
  data = toco_convert_protos(
      model_flags.SerializeToString(),
      toco_flags.SerializeToString(),
      None,  # input_data, unused
      None,  # debug_info_str, unused
      enable_mlir_converter=True)
  return data

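
# Illustrative sketch (not executed by this module): converting a SavedModel
# directly through the MLIR converter. The path and tag below are assumptions.
#
#   tflite_model = convert_saved_model(
#       saved_model_dir="/tmp/saved_model",
#       saved_model_version=1,
#       saved_model_tags=set(["serve"]),
#       inference_type=dtypes.float32)
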

@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a model using TOCO.

  Typically this function is used to convert from TensorFlow GraphDef to
  TFLite. Conversion can be customized by providing arguments that are
  forwarded to `build_toco_convert_protos` (see documentation for details).
  This function has been deprecated. Please use `tf.lite.TFLiteConverter`
  instead.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_toco_convert_protos`.
    **kwargs: See `build_toco_convert_protos`.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_toco_convert_protos`.
  """
  enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
  return toco_convert_impl(input_data, input_tensors, output_tensors,
                           enable_mlir_converter, *args, **kwargs)

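
# Illustrative sketch (not executed by this module): the deprecated
# `toco_convert` entry point simply forwards to `toco_convert_impl`; new code
# should use `tf.lite.TFLiteConverter` instead. `sess`, `in_tensor` and
# `out_tensor` are assumed names.
#
#   tflite_model = toco_convert(sess.graph_def, [in_tensor], [out_tensor])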