1# Lint as: python2, python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Functions used by multiple converter files.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import copy 23import datetime 24import sys 25 26from absl import logging 27import six 28from six.moves import range 29 30import flatbuffers 31from tensorflow.core.protobuf import config_pb2 as _config_pb2 32from tensorflow.core.protobuf import graph_debug_info_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 34from tensorflow.lite.python import schema_py_generated as schema_fb 35from tensorflow.lite.python import schema_util 36from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util 37from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs 38from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes 39from tensorflow.lite.toco import types_pb2 as _types_pb2 40from tensorflow.python.eager import function 41from tensorflow.python.framework import convert_to_constants as _convert_to_constants 42from tensorflow.python.framework import dtypes 43from tensorflow.python.framework import error_interpolation as _error_interpolation 44from tensorflow.python.framework import graph_util as tf_graph_util 45from tensorflow.python.grappler import tf_optimizer 46from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph 47 48# Keras functions used by TFLite 49model_input_signature = _tflite_keras_util.model_input_signature 50trace_model_call = _tflite_keras_util.trace_model_call 51 52# Map of tf.dtypes to TFLite types_flag_pb2. 53_MAP_TF_TO_TFLITE_TYPES = { 54 dtypes.float32: _types_pb2.FLOAT, 55 dtypes.float16: _types_pb2.FLOAT16, 56 dtypes.int32: _types_pb2.INT32, 57 dtypes.uint8: _types_pb2.QUANTIZED_UINT8, 58 dtypes.int64: _types_pb2.INT64, 59 dtypes.uint64: _types_pb2.UINT64, 60 dtypes.string: _types_pb2.STRING, 61 dtypes.bool: _types_pb2.BOOL, 62 dtypes.int16: _types_pb2.QUANTIZED_INT16, 63 dtypes.complex64: _types_pb2.COMPLEX64, 64 dtypes.int8: _types_pb2.INT8, 65 dtypes.float64: _types_pb2.FLOAT64, 66 dtypes.complex128: _types_pb2.COMPLEX128, 67 dtypes.resource: _types_pb2.RESOURCE, 68 dtypes.variant: _types_pb2.VARIANT, 69 dtypes.uint32: _types_pb2.UINT32, 70} 71 72_MAP_TFLITE_ENUM_TO_TF_TYPES = { 73 0: dtypes.float32, 74 1: dtypes.float16, 75 2: dtypes.int32, 76 3: dtypes.uint8, 77 4: dtypes.int64, 78 5: dtypes.string, 79 6: dtypes.bool, 80 7: dtypes.int16, 81 8: dtypes.complex64, 82 9: dtypes.int8, 83 10: dtypes.float64, 84 11: dtypes.complex128, 85 16: dtypes.uint32, 86} 87 88_TFLITE_FILE_IDENTIFIER = b"TFL3" 89 90_MAP_QUANT_TO_IO_TYPES = { 91 dtypes.int8: {dtypes.int8, dtypes.uint8}, 92 dtypes.int16: {dtypes.int16}, 93} 94 95 96def convert_dtype_to_tflite_type(tf_dtype): 97 """Converts tf.dtype to TFLite proto type. 98 99 Args: 100 tf_dtype: tf.dtype 101 102 Raises: 103 ValueError: Unsupported tf.dtype. 104 105 Returns: 106 types_flag_pb2. 107 """ 108 result = _MAP_TF_TO_TFLITE_TYPES.get(tf_dtype) 109 if result is None: 110 raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype)) 111 return result 112 113 114def _convert_tflite_enum_type_to_tf_type(tflite_enum_type): 115 """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32). 116 117 Args: 118 tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32) 119 120 Raises: 121 ValueError: If an invalid tflite enum type is provided. 122 123 Returns: 124 tf type (eg: tf.float32) 125 """ 126 tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type) 127 if tf_type is None: 128 raise ValueError( 129 "Unsupported enum {}. The valid map of enum to tf types is : {}" 130 .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES)) 131 return tf_type 132 133 134def get_tf_type_name(tf_type): 135 """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32").""" 136 return "tf." + tf_type.name if tf_type else None 137 138 139def get_tensor_name(tensor): 140 """Returns name of the input tensor. 141 142 Args: 143 tensor: tf.Tensor 144 145 Returns: 146 str 147 """ 148 parts = six.ensure_str(tensor.name).split(":") 149 if len(parts) > 2: 150 raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( 151 len(parts) - 1)) 152 153 # To be consistent with the tensor naming scheme in tensorflow, we need 154 # drop the ':0' suffix for the first tensor. 155 if len(parts) > 1 and parts[1] != "0": 156 return tensor.name 157 return parts[0] 158 159 160def get_tensors_from_tensor_names(graph, tensor_names): 161 """Gets the Tensors associated with the `tensor_names` in the provided graph. 162 163 Args: 164 graph: TensorFlow Graph. 165 tensor_names: List of strings that represent names of tensors in the graph. 166 167 Returns: 168 A list of Tensor objects in the same order the names are provided. 169 170 Raises: 171 ValueError: 172 tensor_names contains an invalid tensor name. 173 """ 174 # Get the list of all of the tensors. 175 tensor_name_to_tensor = {} 176 for op in graph.get_operations(): 177 for tensor in op.values(): 178 tensor_name_to_tensor[get_tensor_name(tensor)] = tensor 179 180 # Get the tensors associated with tensor_names. 181 tensors = [] 182 invalid_tensors = [] 183 for name in tensor_names: 184 if not isinstance(name, six.string_types): 185 raise ValueError("Invalid type for a tensor name in the provided graph. " 186 "Expected type for a tensor name is 'str', instead got " 187 "type '{}' for tensor name '{}'".format( 188 type(name), name)) 189 190 tensor = tensor_name_to_tensor.get(name) 191 if tensor is None: 192 invalid_tensors.append(name) 193 else: 194 tensors.append(tensor) 195 196 # Throw ValueError if any user input names are not valid tensors. 197 if invalid_tensors: 198 raise ValueError("Invalid tensors '{}' were found.".format( 199 ",".join(invalid_tensors))) 200 return tensors 201 202 203def set_tensor_shapes(tensors, shapes): 204 """Sets Tensor shape for each tensor if the shape is defined. 205 206 Args: 207 tensors: TensorFlow ops.Tensor. 208 shapes: Dict of strings representing input tensor names to list of 209 integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). 210 211 Raises: 212 ValueError: 213 `shapes` contains an invalid tensor. 214 `shapes` contains an invalid shape for a valid tensor. 215 """ 216 if shapes: 217 tensor_names_to_tensor = { 218 get_tensor_name(tensor): tensor for tensor in tensors 219 } 220 for name, shape in shapes.items(): 221 if name not in tensor_names_to_tensor: 222 raise ValueError("Invalid tensor \'{}\' found in tensor shapes " 223 "map.".format(name)) 224 if shape is not None: 225 tensor = tensor_names_to_tensor[name] 226 try: 227 tensor.set_shape(shape) 228 except ValueError as error: 229 message = ("The shape of tensor '{0}' cannot be changed from {1} to " 230 "{2}. {3}".format(name, tensor.shape, shape, str(error))) 231 raise ValueError(message) 232 233 234def get_grappler_config(optimizers_list): 235 """Creates a tf.compat.v1.ConfigProto for configuring Grappler. 236 237 Args: 238 optimizers_list: List of strings that represents the list of optimizers. 239 240 Returns: 241 tf.ConfigProto. 242 """ 243 config = _config_pb2.ConfigProto() 244 rewrite_options = config.graph_options.rewrite_options 245 for optimizer in optimizers_list: 246 rewrite_options.optimizers.append(optimizer) 247 return config 248 249 250def run_graph_optimizations(graph_def, 251 input_arrays, 252 output_arrays, 253 config, 254 graph=None): 255 """Apply standard TensorFlow optimizations to the graph_def. 256 257 Args: 258 graph_def: Frozen GraphDef to be optimized. 259 input_arrays: List of arrays that are considered inputs of the graph. 260 output_arrays: List of arrays that are considered outputs of the graph. 261 config: tf.ConfigProto. 262 graph: TensorFlow Graph. Required when Eager mode is enabled. (default None) 263 264 Returns: 265 A new, optimized GraphDef. 266 """ 267 meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph) 268 269 signature = _meta_graph_pb2.SignatureDef() 270 for array in input_arrays: 271 signature.inputs[array.name].name = array.name 272 signature.inputs[array.name].dtype = array.dtype.as_datatype_enum 273 signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) 274 275 for array in output_arrays: 276 signature.outputs[array.name].name = array.name 277 signature.outputs[array.name].dtype = array.dtype.as_datatype_enum 278 signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) 279 280 meta_graph.signature_def["not_used_key"].CopyFrom(signature) 281 282 # We need to add a collection called 'train_op' so that grappler 283 # knows what the outputs are. 284 fetch_collection = _meta_graph_pb2.CollectionDef() 285 for array in input_arrays + output_arrays: 286 fetch_collection.node_list.value.append(array.name) 287 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 288 289 return tf_optimizer.OptimizeGraph(config, meta_graph) 290 291 292def _convert_op_hints_if_present(sess, graph_def, output_tensors, 293 hinted_outputs_nodes): 294 if is_frozen_graph(sess): 295 raise ValueError("Try to convert op hints, needs unfrozen graph.") 296 output_arrays = [get_tensor_name(tensor) for tensor in output_tensors] 297 graph_def = tf_graph_util.convert_variables_to_constants( 298 sess, graph_def, output_arrays + hinted_outputs_nodes) 299 graph_def = convert_op_hints_to_stubs(graph_def=graph_def) 300 return graph_def 301 302 303def freeze_graph(sess, input_tensors, output_tensors): 304 """Returns a frozen GraphDef. 305 306 Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the 307 existing GraphDef is returned. The Grappler pass is only run on models that 308 are frozen in order to inline the functions in the graph. 309 If OpHints is present, it will try to convert the OpHint graph. 310 311 Args: 312 sess: TensorFlow Session. 313 input_tensors: List of input tensors. 314 output_tensors: List of output tensors (only .name is used from this). 315 316 Returns: 317 Frozen GraphDef. 318 """ 319 # Runs a Grappler pass in order to inline any functions in the graph. 320 # Asides from inlining any simple function, Grappler will also try to lower 321 # while loop into switch merge representation which is undesired for Ophints, 322 # so we simply remove those attributes to prevent Grappler from doing so. 323 graph_def = _convert_to_constants.disable_lower_using_switch_merge( 324 sess.graph_def) 325 config = get_grappler_config(["function"]) 326 graph_def = run_graph_optimizations( 327 graph_def, input_tensors, output_tensors, config, graph=sess.graph) 328 329 # If ophints are present, just convert them. 330 hinted_outputs_nodes = find_all_hinted_output_nodes(sess) 331 if hinted_outputs_nodes: 332 return _convert_op_hints_if_present(sess, graph_def, output_tensors, 333 hinted_outputs_nodes) 334 335 if not is_frozen_graph(sess): 336 output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors] 337 return tf_graph_util.convert_variables_to_constants(sess, graph_def, 338 output_node_names) 339 else: 340 return sess.graph_def 341 342 343def is_frozen_graph(sess): 344 """Determines if the graph is frozen. 345 346 Determines if a graph has previously been frozen by checking for any 347 operations of type Variable*. If variables are found, the graph is not frozen. 348 349 Args: 350 sess: TensorFlow Session. 351 352 Returns: 353 Bool. 354 """ 355 for op in sess.graph.get_operations(): 356 if six.ensure_str(op.type).startswith("Variable") or six.ensure_str( 357 op.type).endswith("VariableOp"): 358 return False 359 return True 360 361 362def build_debug_info_func(original_graph): 363 """Returns a method to retrieve the `GraphDebugInfo` from the original graph. 364 365 Args: 366 original_graph: The original `Graph` containing all the op stack traces. 367 368 Returns: 369 A function which retrieves the stack traces from the original graph and 370 converts them to a `GraphDebugInfo` for a given set of nodes. 371 """ 372 373 def f(original_nodes): 374 """Function to create `GraphDebugInfo` for the given `original_nodes`.""" 375 if not original_graph: 376 return None 377 # For the given nodes, gets all the op definitions in the original graph. 378 useful_ops = [] 379 for func, name in original_nodes: 380 try: 381 if not func: 382 useful_ops.append((func, original_graph.get_operation_by_name(name))) 383 else: 384 sub_func = original_graph._get_function(func) # pylint: disable=protected-access 385 if isinstance(sub_func, function._EagerDefinedFunction): # pylint: disable=protected-access 386 useful_ops.append( 387 (func, sub_func.graph.get_operation_by_name(name))) 388 else: 389 sys.stderr.write( 390 "Use '@tf.function' or '@defun' to decorate the function.\n") 391 continue 392 except KeyError: 393 # New node created by graph optimizer. No stack trace from source code. 394 continue 395 # Convert all the op definitions to stack traces in terms of GraphDebugInfo. 396 return _error_interpolation.create_graph_debug_info_def(useful_ops) 397 398 return f 399 400 401def convert_debug_info_func(saved_debug_info): 402 """Returns a method to retrieve the `GraphDebugInfo` from the original graph. 403 404 Args: 405 saved_debug_info: The `GraphDebugInfo` containing all the debug info. 406 407 Returns: 408 A function which retrieves the stack traces from the original graph and 409 converts them to a `GraphDebugInfo` for a given set of nodes. 410 """ 411 412 def f(original_nodes): 413 """Function to create `GraphDebugInfo` for the given `original_nodes`.""" 414 if not saved_debug_info: 415 return None 416 417 output_debug_info = graph_debug_info_pb2.GraphDebugInfo() 418 # All the files are copied over, so the index wouldn't be changed. 419 output_debug_info.files[:] = saved_debug_info.files 420 # We only copy over the debug info for the input nodes 421 for func, node in original_nodes: 422 debug_key = node + "@" + func 423 output_debug_info.traces[debug_key].CopyFrom( 424 saved_debug_info.traces[debug_key]) 425 return output_debug_info 426 427 return f 428 429 430def get_debug_info(nodes_to_debug_info_func, converted_graph): 431 """Returns the debug info for the original nodes in the `converted_graph`. 432 433 Args: 434 nodes_to_debug_info_func: The method to collect the op debug info for the 435 nodes. 436 converted_graph: A `GraphDef` after optimization and transformation. 437 438 Returns: 439 `GraphDebugInfo` for all the original nodes in `converted_graph`. 440 """ 441 if not nodes_to_debug_info_func: 442 return None 443 444 # Collect all the debug info nodes from the converted_graph 445 original_nodes = set() 446 for node in converted_graph.node: 447 debug_nodes = node.experimental_debug_info.original_node_names 448 debug_funcs = node.experimental_debug_info.original_func_names 449 # If the `original_node_names` are empty, uses the node name directly. 450 if not debug_nodes: 451 original_nodes.add(("", node.name)) 452 else: 453 for i in range(len(debug_nodes)): 454 debug_func = "" if i >= len(debug_funcs) else debug_funcs[i] 455 original_nodes.add((debug_func, debug_nodes[i])) 456 457 # Convert the nodes to the debug info proto object. 458 return nodes_to_debug_info_func(original_nodes) 459 460 461def convert_bytes_to_c_source(data, 462 array_name, 463 max_line_width=80, 464 include_guard=None, 465 include_path=None, 466 use_tensorflow_license=False): 467 """Returns strings representing a C constant array containing `data`. 468 469 Args: 470 data: Byte array that will be converted into a C constant. 471 array_name: String to use as the variable name for the constant array. 472 max_line_width: The longest line length, for formatting purposes. 473 include_guard: Name to use for the include guard macro definition. 474 include_path: Optional path to include in the source file. 475 use_tensorflow_license: Whether to include the standard TensorFlow Apache2 476 license in the generated files. 477 478 Returns: 479 Text that can be compiled as a C source file to link in the data as a 480 literal array of values. 481 Text that can be used as a C header file to reference the literal array. 482 """ 483 484 starting_pad = " " 485 array_lines = [] 486 array_line = starting_pad 487 for value in bytearray(data): 488 if (len(array_line) + 4) > max_line_width: 489 array_lines.append(array_line + "\n") 490 array_line = starting_pad 491 array_line += " 0x%02x," % (value) 492 if len(array_line) > len(starting_pad): 493 array_lines.append(array_line + "\n") 494 array_values = "".join(array_lines) 495 496 if include_guard is None: 497 include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_" 498 499 if include_path is not None: 500 include_line = "#include \"{include_path}\"\n".format( 501 include_path=include_path) 502 else: 503 include_line = "" 504 505 if use_tensorflow_license: 506 license_text = """ 507/* Copyright {year} The TensorFlow Authors. All Rights Reserved. 508 509Licensed under the Apache License, Version 2.0 (the "License"); 510you may not use this file except in compliance with the License. 511You may obtain a copy of the License at 512 513 http://www.apache.org/licenses/LICENSE-2.0 514 515Unless required by applicable law or agreed to in writing, software 516distributed under the License is distributed on an "AS IS" BASIS, 517WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 518See the License for the specific language governing permissions and 519limitations under the License. 520==============================================================================*/ 521""".format(year=datetime.date.today().year) 522 else: 523 license_text = "" 524 525 source_template = """{license_text} 526// This is a TensorFlow Lite model file that has been converted into a C data 527// array using the tensorflow.lite.util.convert_bytes_to_c_source() function. 528// This form is useful for compiling into a binary for devices that don't have a 529// file system. 530 531{include_line} 532// We need to keep the data array aligned on some architectures. 533#ifdef __has_attribute 534#define HAVE_ATTRIBUTE(x) __has_attribute(x) 535#else 536#define HAVE_ATTRIBUTE(x) 0 537#endif 538#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) 539#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) 540#else 541#define DATA_ALIGN_ATTRIBUTE 542#endif 543 544const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{ 545{array_values}}}; 546const int {array_name}_len = {array_length}; 547""" 548 549 source_text = source_template.format( 550 array_name=array_name, 551 array_length=len(data), 552 array_values=array_values, 553 license_text=license_text, 554 include_line=include_line) 555 556 header_template = """ 557{license_text} 558 559// This is a TensorFlow Lite model file that has been converted into a C data 560// array using the tensorflow.lite.util.convert_bytes_to_c_source() function. 561// This form is useful for compiling into a binary for devices that don't have a 562// file system. 563 564#ifndef {include_guard} 565#define {include_guard} 566 567extern const unsigned char {array_name}[]; 568extern const int {array_name}_len; 569 570#endif // {include_guard} 571""" 572 573 header_text = header_template.format( 574 array_name=array_name, 575 include_guard=include_guard, 576 license_text=license_text) 577 578 return source_text, header_text 579 580 581def _convert_model_from_bytearray_to_object(model_bytearray): 582 """Converts a tflite model from a bytearray into a parsable object.""" 583 model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) 584 model_object = schema_fb.ModelT.InitFromObj(model_object) 585 model_object = copy.deepcopy(model_object) 586 model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0] 587 return model_object 588 589 590def _convert_model_from_object_to_bytearray(model_object): 591 """Converts a tflite model from a parsable object into a bytearray.""" 592 # Initial size of the buffer, which will grow automatically if needed 593 builder = flatbuffers.Builder(1024) 594 model_offset = model_object.Pack(builder) 595 builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) 596 return bytes(builder.Output()) 597 598 599def _remove_tensors_from_model(model, remove_tensors_idxs): 600 """Remove tensors from model.""" 601 if not remove_tensors_idxs: 602 return 603 if len(model.subgraphs) > 1: 604 raise ValueError("Model must only have one subgraph. Instead, it has " 605 "{} subgraphs.".format(len(model.subgraphs))) 606 subgraph = model.subgraphs[0] 607 tensors = subgraph.tensors 608 operators = subgraph.operators 609 610 logging.debug("Removing tensors at indices : %s", remove_tensors_idxs) 611 # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an 612 # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]). 613 if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs): 614 logging.debug("Removing tensors only at the end of the tensor list") 615 del tensors[min(remove_tensors_idxs):] 616 else: 617 logging.debug("Removing tensors requires updating the model") 618 # Map the old tensor indices to new tensor indices 619 d_old_to_new_tensors = {} 620 left_shift_by = 0 621 for idx in range(len(tensors)): 622 if idx in remove_tensors_idxs: 623 left_shift_by += 1 624 else: 625 d_old_to_new_tensors[idx] = idx - left_shift_by 626 logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__()) 627 # Update tensor indices referenced throughout the model 628 def update_tensors(tensor_idxs): 629 for i, ti in enumerate(tensor_idxs): 630 tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1) 631 update_tensors(subgraph.inputs) 632 update_tensors(subgraph.outputs) 633 for op in operators: 634 update_tensors(op.inputs) 635 update_tensors(op.outputs) 636 # Delete the tensors 637 for idx in sorted(remove_tensors_idxs, reverse=True): 638 tensors.pop(idx) 639 logging.debug("Removed tensors marked for deletion") 640 641 642def _modify_model_input_type(model, inference_input_type=dtypes.float32): 643 """Modify model input type.""" 644 645 if inference_input_type == dtypes.float32: 646 return 647 648 subgraph = model.subgraphs[0] 649 tensors = subgraph.tensors 650 operators = subgraph.operators 651 652 # Find all quantize operators 653 quant_opcode_idxs = [] 654 for idx, opcode in enumerate(model.operatorCodes): 655 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 656 if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: 657 quant_opcode_idxs.append(idx) 658 if operators and not quant_opcode_idxs: 659 for input_idx in subgraph.inputs: 660 input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type) 661 if input_type == dtypes.float32: 662 raise ValueError("Model input is not dequantized.") 663 # None of the inputs have float32, then they must be int16, int8, or bool 664 return 665 666 # Validate that the model input is quantized 667 input_quant_ops = [] 668 for op in operators: 669 # Find operators that quantize model input 670 if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs: 671 float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] 672 # If found, validate that the operator's input type is float 673 float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) 674 if float_type != dtypes.float32: 675 if float_type == inference_input_type: 676 continue 677 else: 678 raise ValueError( 679 "Initial model input type must be tf.float32. Expected type for " 680 "tensor with name '{}' is tf.float32, instead type is {}".format( 681 float_tensor.name, get_tf_type_name(float_type))) 682 # If found, validate that the operator output is quantized and compatible 683 # with the final model input type 684 quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) 685 if quant_type not in _MAP_QUANT_TO_IO_TYPES: 686 raise ValueError( 687 "Initial model input is not quantized. Expected type for " 688 "tensor with name '{}' should be in {}, instead type is {}".format( 689 quant_tensor.name, 690 tuple(get_tf_type_name(t) for t in 691 _MAP_QUANT_TO_IO_TYPES.keys()), 692 get_tf_type_name(quant_type))) 693 else: 694 inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] 695 if inference_input_type not in inference_io_types: 696 raise ValueError( 697 "Unsupported `inference_input_type` value. Expected to be in " 698 "{}, instead got {}.".format( 699 tuple(get_tf_type_name(t) for t in inference_io_types), 700 get_tf_type_name(inference_input_type))) 701 input_quant_ops.append(op) 702 703 if len(subgraph.inputs) != len(input_quant_ops): 704 logging.warning( 705 "For model inputs containing unsupported operations which cannot be " 706 "quantized, the `inference_input_type` attribute will default to the " 707 "original type." 708 ) 709 710 # Modify model input type 711 if inference_input_type == dtypes.uint8: 712 # Change quant op (float to int8) to quant op (uint8 to int8) 713 for op in input_quant_ops: 714 int8_quantization = tensors[op.outputs[0]].quantization 715 uint8_quantization = schema_fb.QuantizationParametersT() 716 uint8_quantization.scale = [int8_quantization.scale[0]] 717 uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] 718 tensors[op.inputs[0]].quantization = uint8_quantization 719 tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8 720 elif inference_input_type in _MAP_QUANT_TO_IO_TYPES: 721 # Remove the inputs and the quant operator 722 remove_tensors_idxs = set() 723 for op in input_quant_ops: 724 subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0] 725 remove_tensors_idxs.add(op.inputs[0]) 726 operators.remove(op) 727 # Remove tensors marked for deletion. 728 _remove_tensors_from_model(model, remove_tensors_idxs) 729 else: 730 raise ValueError( 731 "Unsupported `inference_input_type` value {}.".format( 732 get_tf_type_name(inference_input_type))) 733 734 735def _modify_model_output_type(model, inference_output_type=dtypes.float32): 736 """Modify model output type.""" 737 738 if inference_output_type == dtypes.float32: 739 return 740 741 subgraph = model.subgraphs[0] 742 tensors = subgraph.tensors 743 operators = subgraph.operators 744 745 # Find all dequantize operators 746 dequant_opcode_idxs = [] 747 for idx, opcode in enumerate(model.operatorCodes): 748 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 749 if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE: 750 dequant_opcode_idxs.append(idx) 751 if operators and not dequant_opcode_idxs: 752 for output in subgraph.outputs: 753 output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type) 754 if output_type == dtypes.float32: 755 raise ValueError("Model output is not dequantized.") 756 # None of the outputs have float32, then they must be int16, int8, or bool 757 return 758 759 # Validate that the model output is dequantized 760 output_dequant_ops = [] 761 for op in operators: 762 # Find operators that dequantize model output 763 if op.opcodeIndex in dequant_opcode_idxs and \ 764 op.outputs[0] in subgraph.outputs: 765 # If found, validate that the operator's output type is float 766 quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] 767 float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) 768 if float_type != dtypes.float32: 769 if float_type == inference_output_type: 770 continue 771 else: 772 raise ValueError( 773 "Initial model output type must be tf.float32. Expected type for " 774 "tensor with name '{}' is tf.float32, instead type is {}".format( 775 float_tensor.name, get_tf_type_name(float_type))) 776 # If found, validate that the operator input is quantized and compatible 777 # with the final model output type 778 quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) 779 if quant_type not in _MAP_QUANT_TO_IO_TYPES: 780 raise ValueError( 781 "Initial model output is not dequantized. Expected type for " 782 "tensor with name '{}' should be in {}, instead type is {}".format( 783 quant_tensor.name, 784 tuple(get_tf_type_name(t) for t in 785 _MAP_QUANT_TO_IO_TYPES.keys()), 786 get_tf_type_name(quant_type))) 787 else: 788 inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] 789 if inference_output_type not in inference_io_types: 790 raise ValueError( 791 "Unsupported `inference_output_type` value. Expected to be in " 792 "{}, instead got {}.".format( 793 tuple(get_tf_type_name(t) for t in inference_io_types), 794 get_tf_type_name(inference_output_type))) 795 output_dequant_ops.append(op) 796 797 if len(subgraph.outputs) != len(output_dequant_ops): 798 logging.warning( 799 "For model outputs containing unsupported operations which cannot be " 800 "quantized, the `inference_output_type` attribute will default to the " 801 "original type." 802 ) 803 804 # Modify model output type 805 if inference_output_type == dtypes.uint8: 806 # Find a quantize operator 807 quant_opcode_idx = -1 808 for idx, opcode in enumerate(model.operatorCodes): 809 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 810 if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: 811 quant_opcode_idx = idx 812 break 813 # Create a quantize operator, if none exist 814 if quant_opcode_idx == -1: 815 quant_op = schema_fb.OperatorCodeT() 816 quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE 817 quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE 818 model.operatorCodes.append(quant_op) 819 quant_opcode_idx = len(model.operatorCodes) - 1 820 # Change dequant op (int8 to float) to quant op (int8 to uint8) 821 for op in output_dequant_ops: 822 op.opcodeIndex = quant_opcode_idx 823 int8_quantization = tensors[op.inputs[0]].quantization 824 uint8_quantization = schema_fb.QuantizationParametersT() 825 uint8_quantization.scale = [int8_quantization.scale[0]] 826 uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] 827 tensors[op.outputs[0]].quantization = uint8_quantization 828 tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8 829 elif inference_output_type in _MAP_QUANT_TO_IO_TYPES: 830 # Remove the outputs and the dequant operator 831 remove_tensors_idxs = set() 832 for op in output_dequant_ops: 833 subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] 834 remove_tensors_idxs.add(op.outputs[0]) 835 operators.remove(op) 836 # Remove tensors marked for deletion. 837 _remove_tensors_from_model(model, remove_tensors_idxs) 838 else: 839 raise ValueError( 840 "Unsupported `inference_output_type` value {}.".format( 841 get_tf_type_name(inference_output_type))) 842 843 844def modify_model_io_type( 845 model, inference_input_type=dtypes.float32, 846 inference_output_type=dtypes.float32): 847 """Modify the input/output type of a tflite model. 848 849 Args: 850 model: A tflite model. 851 inference_input_type: tf.DType representing modified input type. 852 (default tf.float32. If model input is int8 quantized, it must be in 853 {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized, 854 it must be in {tf.float32, tf.int16}, else it must be tf.float32) 855 inference_output_type: tf.DType representing modified output type. 856 (default tf.float32. If model output is int8 dequantized, it must be in 857 {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized, 858 it must be in {tf.float32, tf.int16}, else it must be tf.float32) 859 Returns: 860 A tflite model with modified input/output type. 861 862 Raises: 863 ValueError: If `inference_input_type`/`inference_output_type` is unsupported 864 or a supported integer type is specified for a model whose input/output is 865 not quantized/dequantized. 866 RuntimeError: If the modification was unsuccessful. 867 868 """ 869 if inference_input_type == dtypes.float32 and \ 870 inference_output_type == dtypes.float32: 871 return model 872 873 model_object = _convert_model_from_bytearray_to_object(model) 874 875 if len(model_object.subgraphs) > 1: 876 raise ValueError("Model must only have one subgraph. Instead, it has " 877 "{} subgraphs.".format(len(model_object.subgraphs))) 878 879 _modify_model_input_type(model_object, inference_input_type) 880 881 _modify_model_output_type(model_object, inference_output_type) 882 883 return _convert_model_from_object_to_bytearray(model_object) 884