1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow Lite tooling helper functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import warnings 22import enum 23from six import PY3 24 25from google.protobuf import text_format as _text_format 26from google.protobuf.message import DecodeError 27from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import 28from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import 29from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import 30from tensorflow.lite.python import lite_constants as constants 31from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import 32from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import 33from tensorflow.lite.python.convert import OpsSet 34from tensorflow.lite.python.convert import tensor_name as _tensor_name 35from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import 36from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def 37from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl 38from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import 39from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model 40from tensorflow.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names 41from tensorflow.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes 42from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import 43from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import 44from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import 45from tensorflow.lite.python.optimize import calibrator as _calibrator 46from tensorflow.core.framework import graph_pb2 as _graph_pb2 47from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2 48from tensorflow.core.protobuf import config_pb2 as _config_pb2 49from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 50from tensorflow.python import keras as _keras 51from tensorflow.python.client import session as _session 52from tensorflow.python.eager import def_function as _def_function 53from tensorflow.python.eager import function as _function 54from tensorflow.python.framework import convert_to_constants as _convert_to_constants 55from tensorflow.python.framework import dtypes as _dtypes 56from tensorflow.python.framework import graph_util as _tf_graph_util 57from tensorflow.python.framework import ops as _ops 58from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError 59from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 60from tensorflow.python.grappler import tf_optimizer as _tf_optimizer 61from tensorflow.python.lib.io import file_io as _file_io 62from tensorflow.python.saved_model import signature_constants as _signature_constants 63from tensorflow.python.saved_model import tag_constants as _tag_constants 64from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph 65from tensorflow.python.util import deprecation as _deprecation 66from tensorflow.python.util.tf_export import tf_export as _tf_export 67 68 69def _run_graph_optimizations(graph_def, input_arrays, output_arrays, 70 graph=None): 71 """Apply standard TensorFlow optimizations to the graph_def. 72 73 Args: 74 graph_def: Frozen GraphDef to be optimized. 75 input_arrays: List of arrays that are considered inputs of the graph. 76 output_arrays: List of arrays that are considered outputs of the graph. 77 graph: TensorFlow Graph. Required when Eager mode is enabled. (default None) 78 79 Returns: 80 A new, optimized GraphDef. 81 """ 82 meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph) 83 84 # We need to add a collection called 'train_op' so that grappler 85 # knows what the outputs are. 86 fetch_collection = _meta_graph_pb2.CollectionDef() 87 for array in input_arrays + output_arrays: 88 fetch_collection.node_list.value.append(array.name) 89 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 90 91 config = _config_pb2.ConfigProto() 92 rewrite_options = config.graph_options.rewrite_options 93 rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON 94 # Avoid remapping as it creates ops like _FusedConv2D, which are not 95 # supported by TF Lite. 96 rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF 97 return _tf_optimizer.OptimizeGraph(config, meta_graph) 98 99 100@_tf_export("lite.Optimize") 101class Optimize(enum.Enum): 102 """Enum defining the optimizations to apply when generating tflite graphs. 103 104 Some optimizations may come at the cost of accuracy. 105 """ 106 107 # Optimize for size. 108 # 109 # Optimizations that reduce the size of the model. 110 # The model size will be reduced. Optimizations can include quantizing the 111 # weights of the floating point model. 112 OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE" 113 114 # Optimize for latency. 115 # 116 # Optimizations that reduce the latency of the model. 117 # The model latency will be reduced. Optimizations can include quantizing the 118 # weights of the floating point model. 119 OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY" 120 121 def __str__(self): 122 return self.value 123 124 125@_tf_export("lite.RepresentativeDataset") 126class RepresentativeDataset(object): 127 """Representative dataset to evaluate optimizations. 128 129 A representative dataset that can be used to evaluate optimizations by the 130 converter. E.g. converter can use these examples to estimate (min, max) ranges 131 by calibrating the model on inputs. This can allow converter to quantize a 132 converted floating point model. 133 """ 134 135 def __init__(self, input_gen, output_gen=None): 136 """Creates a representative dataset. 137 138 Args: 139 input_gen: an input generator that can be used to generate input samples 140 for the model. This must be a callable object that returns an object 141 that supports the `iter()` protocol (e.g. a generator function). The 142 elements generated must have same type and shape as inputs to the model. 143 output_gen: (optional) an output generator that can be used to generate 144 output samples for the model. This must be a callable object that 145 returns an object that supports the `iter()` protocol (e.g. a generator 146 function). The elements generated must have same type and shape as 147 outputs to the model. (default None) 148 """ 149 self.input_gen = input_gen 150 self.output_gen = output_gen 151 152 153@_tf_export("lite.TargetSpec") 154class TargetSpec(object): 155 """Specification of target device. 156 157 Details about target device. Converter optimizes the generated model for 158 specific device. 159 160 Attributes: 161 supported_ops: Experimental flag, subject to change. Set of OpsSet options 162 supported by the device. (default set([OpsSet.TFLITE_BUILTINS])) 163 """ 164 165 def __init__(self, supported_ops=None): 166 if supported_ops is None: 167 supported_ops = set([OpsSet.TFLITE_BUILTINS]) 168 self.supported_ops = supported_ops 169 170 171@_tf_export("lite.TFLiteConverter", v1=[]) 172class TFLiteConverterV2(object): 173 """Converts a TensorFlow model into TensorFlow Lite model. 174 175 Attributes: 176 allow_custom_ops: Boolean indicating whether to allow custom operations. 177 When false any unknown operation is an error. When true, custom ops are 178 created for any op that is unknown. The developer will need to provide 179 these to the TensorFlow Lite runtime with a custom resolver. (default 180 False) 181 target_spec: Experimental flag, subject to change. Specification of target 182 device. 183 optimizations: Experimental flag, subject to change, A list of optimizations 184 to apply when converting the model. The converter applies the 185 optimizations by giving priority to the optimizations specified earlier in 186 the list. E.g. `[optimize.OPTIMIZE_FOR_SIZE, 187 optimize.OPTIMIZE_FOR_LATENCY]` requires the converter to do both size and 188 latency optimizations giving priority to size optimizations over latency 189 optimizations. 190 representative_dataset: A representative dataset that can be used to 191 generate input and output samples for the model. The converter can use the 192 dataset to evaluate different optimizations. 193 194 Example usage: 195 196 ```python 197 # Converting a GraphDef from a ConcreteFunction. 198 converter = lite.TFLiteConverter.from_concrete_function(func) 199 tflite_model = converter.convert() 200 open("converted_model.tflite", "wb").write(tflite_model) 201 ``` 202 """ 203 204 def __init__(self, func): 205 """Constructor for TFLiteConverter. 206 207 Args: 208 func: TensorFlow ConcreteFunction. 209 """ 210 self._func = func 211 self.allow_custom_ops = False 212 self.target_spec = TargetSpec() 213 self.representative_dataset = None 214 self.optimizations = [] 215 216 @classmethod 217 def from_concrete_function(cls, func): 218 """Creates a TFLiteConverter class from a ConcreteFunction. 219 220 Args: 221 func: TensorFlow ConcreteFunction. 222 223 Returns: 224 TFLiteConverter class. 225 """ 226 if not isinstance(func, _function.ConcreteFunction): 227 message = "This function takes in a ConcreteFunction." 228 if isinstance(func, _def_function.Function): 229 message += (" To get the ConcreteFunction from a Function," 230 " call from_concrete_function.") 231 raise ValueError(message) 232 return cls(func) 233 234 def convert(self): 235 """Converts a TensorFlow GraphDef based on instance variables. 236 237 Returns: 238 The converted data in serialized format. 239 240 Raises: 241 ValueError: 242 Input shape is not specified. 243 None value for dimension in input_tensor. 244 """ 245 frozen_func = _convert_to_constants.convert_variables_to_constants_v2( 246 self._func) 247 input_tensors = [ 248 tensor for tensor in frozen_func.inputs 249 if tensor.dtype != _dtypes.resource 250 ] 251 output_tensors = frozen_func.outputs 252 253 # Run a Grappler pass. 254 graph_def = _run_graph_optimizations(frozen_func.graph.as_graph_def(), 255 input_tensors, output_tensors, 256 frozen_func.graph) 257 258 # Checks dimensions in input tensor. 259 for tensor in input_tensors: 260 # Note that shape_list might be empty for scalar shapes. 261 shape_list = tensor.shape.as_list() 262 if None in shape_list[1:]: 263 raise ValueError( 264 "None is only supported in the 1st dimension. Tensor '{0}' has " 265 "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list)) 266 elif shape_list and shape_list[0] is None: 267 # Set the batch size to 1 if undefined. 268 shape = tensor.shape.as_list() 269 shape[0] = 1 270 tensor.set_shape(shape) 271 272 if self.representative_dataset: 273 if not isinstance(self.representative_dataset, RepresentativeDataset): 274 raise TypeError("`representative_dataset` must be an instance of " 275 "`RepresentativeDataset`") 276 if self.representative_dataset.input_gen is None: 277 raise ValueError( 278 "Provide an input generator for `representative_dataset`") 279 280 # TODO(shashishekhar): For now use optimizations order is ignored. 281 # Both size and latency optimizations decide whether to apply post 282 # training optimizations. 283 post_training_optimize = bool( 284 len( 285 set(self.optimizations) 286 & set([Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE]))) 287 # Do weights only quantization if there is no dataset for calibration. 288 weights_only_quantize_flag = ( 289 post_training_optimize and (self.representative_dataset is None)) 290 291 converter_kwargs = { 292 "input_format": constants.TENSORFLOW_GRAPHDEF, 293 "allow_custom_ops": self.allow_custom_ops, 294 "post_training_quantize": weights_only_quantize_flag, 295 "target_ops": self.target_spec.supported_ops, 296 } 297 298 # Converts model. 299 result = _toco_convert_impl( 300 input_data=graph_def, 301 input_tensors=input_tensors, 302 output_tensors=output_tensors, 303 **converter_kwargs) 304 305 if self.representative_dataset and post_training_optimize: 306 calibrate_quantize = _calibrator.Calibrator(result) 307 result = calibrate_quantize.calibrate_and_quantize( 308 self.representative_dataset.input_gen) 309 310 return result 311 312 313@_tf_export(v1=["lite.TFLiteConverter"]) 314class TFLiteConverter(object): 315 """Convert a TensorFlow model into `output_format` using TOCO. 316 317 This is used to convert from a TensorFlow GraphDef or SavedModel into either a 318 TFLite FlatBuffer or graph visualization. 319 320 Attributes: 321 322 inference_type: Target data type of real-number arrays in the output file. 323 Must be `{tf.float32, tf.uint8}`. (default tf.float32) 324 inference_input_type: Target data type of real-number input arrays. Allows 325 for a different type for input arrays in the case of quantization. 326 Must be `{tf.float32, tf.uint8}`. (default `inference_type`) 327 output_format: Output file format. Currently must be `{TFLITE, 328 GRAPHVIZ_DOT}`. (default TFLITE) 329 quantized_input_stats: Dict of strings representing input tensor names 330 mapped to tuple of floats representing the mean and standard deviation 331 of the training data (e.g., {"foo" : (0., 1.)}). Only need if 332 `inference_input_type` is `QUANTIZED_UINT8`. 333 real_input_value = (quantized_input_value - mean_value) / std_dev_value. 334 (default {}) 335 default_ranges_stats: Tuple of integers representing (min, max) range values 336 for all arrays without a specified range. Intended for experimenting with 337 quantization via "dummy quantization". (default None) 338 drop_control_dependency: Boolean indicating whether to drop control 339 dependencies silently. This is due to TFLite not supporting control 340 dependencies. (default True) 341 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 342 nodes in unexpected locations. Used when the location of the FakeQuant 343 nodes is preventing graph transformations necessary to convert the graph. 344 Results in a graph that differs from the quantized training graph, 345 potentially causing differing arithmetic behavior. (default False) 346 change_concat_input_ranges: Boolean to change behavior of min/max ranges for 347 inputs and outputs of the concat operator for quantized models. Changes 348 the ranges of concat operator overlap when true. (default False) 349 allow_custom_ops: Boolean indicating whether to allow custom operations. 350 When false any unknown operation is an error. When true, custom ops are 351 created for any op that is unknown. The developer will need to provide 352 these to the TensorFlow Lite runtime with a custom resolver. 353 (default False) 354 post_training_quantize: deprecated, please specify 355 `[optimize.OPTIMIZE_FOR_SIZE]` for `optimizations` instead. Boolean 356 indicating whether to quantize the weights of the converted float model. 357 Model size will be reduced and there will be latency improvements 358 (at the cost of accuracy). (default False) 359 dump_graphviz_dir: Full filepath of folder to dump the graphs at various 360 stages of processing GraphViz .dot files. Preferred over 361 --output_format=GRAPHVIZ_DOT in order to keep the requirements of the 362 output file. (default None) 363 dump_graphviz_video: Boolean indicating whether to dump the graph after 364 every graph transformation. (default False) 365 target_ops: Experimental flag, subject to change. Set of OpsSet 366 options indicating which converter to use. 367 (default set([OpsSet.TFLITE_BUILTINS])) 368 optimizations: Experimental flag, subject to change, A list of 369 optimizations to apply when converting the model. The converter applies 370 the optimizations by giving priority to the optimizations specified 371 earlier in the list. E.g. 372 `[optimize.OPTIMIZE_FOR_SIZE, optimize.OPTIMIZE_FOR_LATENCY]` requires 373 the converter to do both size and latency optimizations giving priority 374 to size optimizations over latency optimizations. 375 representative_dataset: A representative dataset that can be used to 376 generate input and output samples for the model. The converter can use 377 the dataset to evaluate different optimizations. 378 379 Example usage: 380 381 ```python 382 # Converting a GraphDef from session. 383 converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors) 384 tflite_model = converter.convert() 385 open("converted_model.tflite", "wb").write(tflite_model) 386 387 # Converting a GraphDef from file. 388 converter = lite.TFLiteConverter.from_frozen_graph( 389 graph_def_file, input_arrays, output_arrays) 390 tflite_model = converter.convert() 391 open("converted_model.tflite", "wb").write(tflite_model) 392 393 # Converting a SavedModel. 394 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 395 tflite_model = converter.convert() 396 397 # Converting a tf.keras model. 398 converter = lite.TFLiteConverter.from_keras_model_file(keras_model) 399 tflite_model = converter.convert() 400 ``` 401 """ 402 403 def __init__(self, 404 graph_def, 405 input_tensors, 406 output_tensors, 407 input_arrays_with_shape=None, 408 output_arrays=None): 409 """Constructor for TFLiteConverter. 410 411 Args: 412 graph_def: Frozen TensorFlow GraphDef. 413 input_tensors: List of input tensors. Type and shape are computed using 414 `foo.shape` and `foo.dtype`. 415 output_tensors: List of output tensors (only .name is used from this). 416 input_arrays_with_shape: Tuple of strings representing input tensor names 417 and list of integers representing input shapes 418 (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded 419 into TensorFlow and when `input_tensors` and `output_tensors` are 420 None. (default None) 421 output_arrays: List of output tensors to freeze graph with. Use only when 422 graph cannot be loaded into TensorFlow and when `input_tensors` and 423 `output_tensors` are None. (default None) 424 425 Raises: 426 ValueError: Invalid arguments. 427 """ 428 self._graph_def = graph_def 429 self._input_tensors = input_tensors 430 self._output_tensors = output_tensors 431 self.inference_type = constants.FLOAT 432 self.inference_input_type = None 433 self.output_format = constants.TFLITE 434 self.quantized_input_stats = {} 435 self.default_ranges_stats = None 436 self.drop_control_dependency = True 437 self.reorder_across_fake_quant = False 438 self.change_concat_input_ranges = False 439 self.allow_custom_ops = False 440 self._post_training_quantize = False 441 self.dump_graphviz_dir = None 442 self.dump_graphviz_video = False 443 self.target_ops = set([OpsSet.TFLITE_BUILTINS]) 444 self.representative_dataset = None 445 self.optimizations = [] 446 447 # Attributes are used by models that cannot be loaded into TensorFlow. 448 if not self._has_valid_tensors(): 449 if not input_arrays_with_shape or not output_arrays: 450 raise ValueError( 451 "If input_tensors and output_tensors are None, both " 452 "input_arrays_with_shape and output_arrays must be defined.") 453 self._input_arrays_with_shape = input_arrays_with_shape 454 self._output_arrays = output_arrays 455 456 @classmethod 457 def from_session(cls, sess, input_tensors, output_tensors): 458 """Creates a TFLiteConverter class from a TensorFlow Session. 459 460 Args: 461 sess: TensorFlow Session. 462 input_tensors: List of input tensors. Type and shape are computed using 463 `foo.shape` and `foo.dtype`. 464 output_tensors: List of output tensors (only .name is used from this). 465 466 Returns: 467 TFLiteConverter class. 468 """ 469 graph_def = _freeze_graph(sess, output_tensors) 470 return cls(graph_def, input_tensors, output_tensors) 471 472 @classmethod 473 def from_frozen_graph(cls, 474 graph_def_file, 475 input_arrays, 476 output_arrays, 477 input_shapes=None): 478 """Creates a TFLiteConverter class from a file containing a frozen GraphDef. 479 480 Args: 481 graph_def_file: Full filepath of file containing frozen GraphDef. 482 input_arrays: List of input tensors to freeze graph with. 483 output_arrays: List of output tensors to freeze graph with. 484 input_shapes: Dict of strings representing input tensor names to list of 485 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 486 Automatically determined when input shapes is None (e.g., {"foo" : 487 None}). (default None) 488 489 Returns: 490 TFLiteConverter class. 491 492 Raises: 493 IOError: 494 File not found. 495 Unable to parse input file. 496 ValueError: 497 The graph is not frozen. 498 input_arrays or output_arrays contains an invalid tensor name. 499 input_shapes is not correctly defined when required 500 """ 501 with _ops.Graph().as_default(): 502 with _session.Session() as sess: 503 # Read GraphDef from file. 504 if not _file_io.file_exists(graph_def_file): 505 raise IOError("File '{0}' does not exist.".format(graph_def_file)) 506 with _file_io.FileIO(graph_def_file, "rb") as f: 507 file_content = f.read() 508 509 try: 510 graph_def = _graph_pb2.GraphDef() 511 graph_def.ParseFromString(file_content) 512 except (_text_format.ParseError, DecodeError): 513 try: 514 print("Ignore 'tcmalloc: large alloc' warnings.") 515 516 if not isinstance(file_content, str): 517 if PY3: 518 file_content = file_content.decode("utf-8") 519 else: 520 file_content = file_content.encode("utf-8") 521 graph_def = _graph_pb2.GraphDef() 522 _text_format.Merge(file_content, graph_def) 523 except (_text_format.ParseError, DecodeError): 524 raise IOError( 525 "Unable to parse input file '{}'.".format(graph_def_file)) 526 527 # Handles models with custom TFLite ops that cannot be resolved in 528 # TensorFlow. 529 load_model_in_session = True 530 try: 531 _import_graph_def(graph_def, name="") 532 except _NotFoundError: 533 load_model_in_session = False 534 535 if load_model_in_session: 536 # Check if graph is frozen. 537 if not _is_frozen_graph(sess): 538 raise ValueError("Please freeze the graph using freeze_graph.py.") 539 540 # Get input and output tensors. 541 input_tensors = _get_tensors_from_tensor_names( 542 sess.graph, input_arrays) 543 output_tensors = _get_tensors_from_tensor_names( 544 sess.graph, output_arrays) 545 _set_tensor_shapes(input_tensors, input_shapes) 546 547 return cls(sess.graph_def, input_tensors, output_tensors) 548 else: 549 if not input_shapes: 550 raise ValueError("input_shapes must be defined for this model.") 551 if set(input_arrays) != set(input_shapes.keys()): 552 raise ValueError("input_shapes must contain a value for each item " 553 "in input_array.") 554 555 input_arrays_with_shape = [ 556 (name, input_shapes[name]) for name in input_arrays 557 ] 558 return cls( 559 graph_def, 560 input_tensors=None, 561 output_tensors=None, 562 input_arrays_with_shape=input_arrays_with_shape, 563 output_arrays=output_arrays) 564 565 @classmethod 566 def from_saved_model(cls, 567 saved_model_dir, 568 input_arrays=None, 569 input_shapes=None, 570 output_arrays=None, 571 tag_set=None, 572 signature_key=None): 573 """Creates a TFLiteConverter class from a SavedModel. 574 575 Args: 576 saved_model_dir: SavedModel directory to convert. 577 input_arrays: List of input tensors to freeze graph with. Uses input 578 arrays from SignatureDef when none are provided. (default None) 579 input_shapes: Dict of strings representing input tensor names to list of 580 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 581 Automatically determined when input shapes is None (e.g., {"foo" : 582 None}). (default None) 583 output_arrays: List of output tensors to freeze graph with. Uses output 584 arrays from SignatureDef when none are provided. (default None) 585 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 586 analyze. All tags in the tag set must be present. (default set("serve")) 587 signature_key: Key identifying SignatureDef containing inputs and outputs. 588 (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) 589 590 Returns: 591 TFLiteConverter class. 592 """ 593 if tag_set is None: 594 tag_set = set([_tag_constants.SERVING]) 595 if signature_key is None: 596 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 597 598 result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, 599 output_arrays, tag_set, signature_key) 600 return cls( 601 graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) 602 603 @classmethod 604 def from_keras_model_file(cls, 605 model_file, 606 input_arrays=None, 607 input_shapes=None, 608 output_arrays=None): 609 """Creates a TFLiteConverter class from a tf.keras model file. 610 611 Args: 612 model_file: Full filepath of HDF5 file containing the tf.keras model. 613 input_arrays: List of input tensors to freeze graph with. Uses input 614 arrays from SignatureDef when none are provided. (default None) 615 input_shapes: Dict of strings representing input tensor names to list of 616 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 617 Automatically determined when input shapes is None (e.g., {"foo" : 618 None}). (default None) 619 output_arrays: List of output tensors to freeze graph with. Uses output 620 arrays from SignatureDef when none are provided. (default None) 621 622 Returns: 623 TFLiteConverter class. 624 """ 625 _keras.backend.clear_session() 626 _keras.backend.set_learning_phase(False) 627 keras_model = _keras.models.load_model(model_file) 628 sess = _keras.backend.get_session() 629 630 # Get input and output tensors. 631 if input_arrays: 632 input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) 633 else: 634 input_tensors = keras_model.inputs 635 636 if output_arrays: 637 output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) 638 else: 639 output_tensors = keras_model.outputs 640 _set_tensor_shapes(input_tensors, input_shapes) 641 642 graph_def = _freeze_graph(sess, output_tensors) 643 return cls(graph_def, input_tensors, output_tensors) 644 645 def __setattr__(self, name, value): 646 if name == "post_training_quantize": 647 warnings.warn("Property %s is deprecated, " 648 "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]" 649 " instead." % name) 650 if value: 651 # Use OPTIMIZE_FOR_SIZE for post training for now. 652 self.optimizations = [Optimize.OPTIMIZE_FOR_SIZE] 653 else: 654 self.optimizations = [] 655 return 656 object.__setattr__(self, name, value) 657 658 def __getattribute__(self, name): 659 if name == "post_training_quantize": 660 warnings.warn("Property %s is deprecated, " 661 "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]" 662 " instead." % name) 663 return Optimize.OPTIMIZE_FOR_SIZE in set(self.optimizations) 664 return object.__getattribute__(self, name) 665 666 def convert(self): 667 """Converts a TensorFlow GraphDef based on instance variables. 668 669 Returns: 670 The converted data in serialized format. Either a TFLite Flatbuffer or a 671 Graphviz graph depending on value in `output_format`. 672 673 Raises: 674 ValueError: 675 Input shape is not specified. 676 None value for dimension in input_tensor. 677 """ 678 # Checks dimensions in input tensor. 679 if self._has_valid_tensors(): 680 for tensor in self._input_tensors: 681 shape = tensor.shape 682 if not shape: 683 raise ValueError("Provide an input shape for input array " 684 "'{0}'.".format(_tensor_name(tensor))) 685 # Note that shape_list might be empty for scalar shapes. 686 shape_list = shape.as_list() 687 if None in shape_list[1:]: 688 raise ValueError( 689 "None is only supported in the 1st dimension. Tensor '{0}' has " 690 "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list)) 691 elif shape_list and shape_list[0] is None: 692 self._set_batch_size(batch_size=1) 693 694 # Get quantization stats. Ensures there is one stat per name if the stats 695 # are specified. 696 if self.quantized_input_stats: 697 quantized_stats = [] 698 invalid_stats = [] 699 for name in self.get_input_arrays(): 700 if name in self.quantized_input_stats: 701 quantized_stats.append(self.quantized_input_stats[name]) 702 else: 703 invalid_stats.append(name) 704 705 if invalid_stats: 706 raise ValueError("Quantization input stats are not available for input " 707 "tensors '{0}'.".format(",".join(invalid_stats))) 708 else: 709 quantized_stats = None 710 if self.representative_dataset: 711 if not isinstance(self.representative_dataset, RepresentativeDataset): 712 raise TypeError( 713 "representative_dataset must be an instance of " 714 "RepresentativeDataset") 715 if self.representative_dataset.input_gen is None: 716 raise ValueError( 717 "Provide an input generator for representative_dataset") 718 719 # TODO(shashishekhar): For now use optimizations order is ignored. 720 # Both size and latency optimizations decide whether to apply post 721 # training optimizations. 722 post_training_optimize = bool( 723 len(set(self.optimizations) & set([Optimize.OPTIMIZE_FOR_LATENCY, 724 Optimize.OPTIMIZE_FOR_SIZE]))) 725 # Do weights only quantization if there is no dataset for calibration. 726 weights_only_quantize_flag = ( 727 post_training_optimize and (self.representative_dataset is None)) 728 729 converter_kwargs = { 730 "inference_type": self.inference_type, 731 "inference_input_type": self.inference_input_type, 732 "input_format": constants.TENSORFLOW_GRAPHDEF, 733 "output_format": self.output_format, 734 "quantized_input_stats": quantized_stats, 735 "default_ranges_stats": self.default_ranges_stats, 736 "drop_control_dependency": self.drop_control_dependency, 737 "reorder_across_fake_quant": self.reorder_across_fake_quant, 738 "change_concat_input_ranges": self.change_concat_input_ranges, 739 "allow_custom_ops": self.allow_custom_ops, 740 "post_training_quantize": weights_only_quantize_flag, 741 "target_ops": self.target_ops, 742 "dump_graphviz_dir": self.dump_graphviz_dir, 743 "dump_graphviz_video": self.dump_graphviz_video 744 } 745 746 optimized_graph = None 747 if self.inference_type == constants.QUANTIZED_UINT8: 748 optimized_graph = self._graph_def 749 else: 750 try: 751 optimized_graph = _run_graph_optimizations( 752 self._graph_def, self._input_tensors, self._output_tensors) 753 except Exception: 754 optimized_graph = self._graph_def 755 756 # Converts model. 757 if self._has_valid_tensors(): 758 result = _toco_convert_impl( 759 input_data=optimized_graph, 760 input_tensors=self._input_tensors, 761 output_tensors=self._output_tensors, 762 **converter_kwargs) 763 else: 764 result = _toco_convert_graph_def( 765 input_data=optimized_graph, 766 input_arrays_with_shape=self._input_arrays_with_shape, 767 output_arrays=self._output_arrays, 768 **converter_kwargs) 769 770 if self.representative_dataset and post_training_optimize: 771 calibrate_quantize = _calibrator.Calibrator(result) 772 result = calibrate_quantize.calibrate_and_quantize( 773 self.representative_dataset.input_gen) 774 775 return result 776 777 def get_input_arrays(self): 778 """Returns a list of the names of the input tensors. 779 780 Returns: 781 List of strings. 782 """ 783 if self._has_valid_tensors(): 784 return [_tensor_name(tensor) for tensor in self._input_tensors] 785 else: 786 return [name for name, _ in self._input_arrays_with_shape] 787 788 def _has_valid_tensors(self): 789 """Checks if the input and output tensors have been initialized. 790 791 Returns: 792 Bool. 793 """ 794 return self._input_tensors and self._output_tensors 795 796 def _set_batch_size(self, batch_size): 797 """Sets the first dimension of the input tensor to `batch_size`. 798 799 Args: 800 batch_size: Batch size for the model. Replaces the first dimension of an 801 input size array if undefined. (default 1) 802 803 Raises: 804 ValueError: input_tensor is not defined. 805 """ 806 if not self._has_valid_tensors(): 807 raise ValueError("The batch size cannot be set for this model. Please " 808 "use input_shapes parameter.") 809 810 for tensor in self._input_tensors: 811 shape = tensor.shape.as_list() 812 shape[0] = batch_size 813 tensor.set_shape(shape) 814 815 816@_tf_export(v1=["lite.TocoConverter"]) 817class TocoConverter(object): 818 """Convert a TensorFlow model into `output_format` using TOCO. 819 820 This class has been deprecated. Please use `lite.TFLiteConverter` instead. 821 """ 822 823 @classmethod 824 @_deprecation.deprecated(None, 825 "Use `lite.TFLiteConverter.from_session` instead.") 826 def from_session(cls, sess, input_tensors, output_tensors): 827 """Creates a TocoConverter class from a TensorFlow Session.""" 828 return TFLiteConverter.from_session(sess, input_tensors, output_tensors) 829 830 @classmethod 831 @_deprecation.deprecated( 832 None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.") 833 def from_frozen_graph(cls, 834 graph_def_file, 835 input_arrays, 836 output_arrays, 837 input_shapes=None): 838 """Creates a TocoConverter class from a file containing a frozen graph.""" 839 return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 840 output_arrays, input_shapes) 841 842 @classmethod 843 @_deprecation.deprecated( 844 None, "Use `lite.TFLiteConverter.from_saved_model` instead.") 845 def from_saved_model(cls, 846 saved_model_dir, 847 input_arrays=None, 848 input_shapes=None, 849 output_arrays=None, 850 tag_set=None, 851 signature_key=None): 852 """Creates a TocoConverter class from a SavedModel.""" 853 return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays, 854 input_shapes, output_arrays, 855 tag_set, signature_key) 856 857 @classmethod 858 @_deprecation.deprecated( 859 None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.") 860 def from_keras_model_file(cls, 861 model_file, 862 input_arrays=None, 863 input_shapes=None, 864 output_arrays=None): 865 """Creates a TocoConverter class from a tf.keras model file.""" 866 return TFLiteConverter.from_keras_model_file(model_file, input_arrays, 867 input_shapes, output_arrays) 868 869 870def _is_frozen_graph(sess): 871 """Determines if the graph is frozen. 872 873 Determines if a graph has previously been frozen by checking for any 874 operations of type Variable*. If variables are found, the graph is not frozen. 875 876 Args: 877 sess: TensorFlow Session. 878 879 Returns: 880 Bool. 881 """ 882 for op in sess.graph.get_operations(): 883 if op.type.startswith("Variable") or op.type.endswith("VariableOp"): 884 return False 885 return True 886 887 888def _freeze_graph(sess, output_tensors): 889 """Returns a frozen GraphDef. 890 891 Freezes a graph with Variables in it. Otherwise the existing GraphDef is 892 returned. 893 894 Args: 895 sess: TensorFlow Session. 896 output_tensors: List of output tensors (only .name is used from this). 897 898 Returns: 899 Frozen GraphDef. 900 """ 901 if not _is_frozen_graph(sess): 902 output_arrays = [_tensor_name(tensor) for tensor in output_tensors] 903 return _tf_graph_util.convert_variables_to_constants( 904 sess, sess.graph_def, output_arrays) 905 else: 906 return sess.graph_def 907