# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Functions used by multiple converter files."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import copy
23import datetime
24import sys
25
26from absl import logging
27import six
28from six.moves import range
29
30import flatbuffers
31from tensorflow.core.protobuf import config_pb2 as _config_pb2
32from tensorflow.core.protobuf import graph_debug_info_pb2
33from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
34from tensorflow.lite.python import schema_py_generated as schema_fb
35from tensorflow.lite.python import schema_util
36from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
37from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
38from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
39from tensorflow.lite.toco import types_pb2 as _types_pb2
40from tensorflow.python.eager import function
41from tensorflow.python.framework import convert_to_constants as _convert_to_constants
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import error_interpolation as _error_interpolation
44from tensorflow.python.framework import graph_util as tf_graph_util
45from tensorflow.python.grappler import tf_optimizer
46from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
47
48# Keras functions used by TFLite
49model_input_signature = _tflite_keras_util.model_input_signature
50trace_model_call = _tflite_keras_util.trace_model_call
51
# Map of tf.dtypes to TFLite types_flag_pb2.
_MAP_TF_TO_TFLITE_TYPES = {
    dtypes.float32: _types_pb2.FLOAT,
    dtypes.float16: _types_pb2.FLOAT16,
    dtypes.int32: _types_pb2.INT32,
    dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
    dtypes.int64: _types_pb2.INT64,
    dtypes.uint64: _types_pb2.UINT64,
    dtypes.string: _types_pb2.STRING,
    dtypes.bool: _types_pb2.BOOL,
    dtypes.int16: _types_pb2.QUANTIZED_INT16,
    dtypes.complex64: _types_pb2.COMPLEX64,
    dtypes.int8: _types_pb2.INT8,
    dtypes.float64: _types_pb2.FLOAT64,
    dtypes.complex128: _types_pb2.COMPLEX128,
    dtypes.resource: _types_pb2.RESOURCE,
    dtypes.variant: _types_pb2.VARIANT,
    dtypes.uint32: _types_pb2.UINT32,
}

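# Map of TFLite schema TensorType enum values to tf.dtypes, used when reading
# tensor types back out of a converted flatbuffer model (see
# _convert_tflite_enum_type_to_tf_type below).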
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

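# Map of quantized tensor types to the inference input/output types they may be
# rewritten to (see _modify_model_input_type and _modify_model_output_type
# below).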
_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def convert_dtype_to_tflite_type(tf_dtype):
  """Converts tf.dtype to TFLite proto type.

  Args:
    tf_dtype: tf.dtype

  Raises:
    ValueError: Unsupported tf.dtype.

  Returns:
    types_flag_pb2.
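
  Example (illustrative, based on the module-level type map):
    convert_dtype_to_tflite_type(dtypes.float32)  # _types_pb2.FLOAT
    convert_dtype_to_tflite_type(dtypes.uint8)  # _types_pb2.QUANTIZED_UINT8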
107  """
108  result = _MAP_TF_TO_TFLITE_TYPES.get(tf_dtype)
109  if result is None:
110    raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype))
111  return result
112
113
114def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
115  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).
116
117  Args:
118    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)
119
120  Raises:
121    ValueError: If an invalid tflite enum type is provided.
122
123  Returns:
124    tf type (eg: tf.float32)
125  """
126  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
127  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None


def get_tensor_name(tensor):
  """Returns name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
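
  Example (illustrative): a tensor named "input:0" yields "input", while a
  tensor named "split:1" is returned unchanged, since only the ":0" suffix is
  dropped.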
147  """
148  parts = six.ensure_str(tensor.name).split(":")
149  if len(parts) > 2:
150    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
151        len(parts) - 1))
152
  # To be consistent with the tensor naming scheme in tensorflow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]


def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
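
  Example (illustrative; assumes `graph` defines tensors named "input" and
  "output"):
    input_tensor, output_tensor = get_tensors_from_tensor_names(
        graph, ["input", "output"])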
173  """
174  # Get the list of all of the tensors.
175  tensor_name_to_tensor = {}
176  for op in graph.get_operations():
177    for tensor in op.values():
178      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor
179
180  # Get the tensors associated with tensor_names.
181  tensors = []
182  invalid_tensors = []
183  for name in tensor_names:
184    if not isinstance(name, six.string_types):
185      raise ValueError("Invalid type for a tensor name in the provided graph. "
186                       "Expected type for a tensor name is 'str', instead got "
187                       "type '{}' for tensor name '{}'".format(
188                           type(name), name))
189
190    tensor = tensor_name_to_tensor.get(name)
191    if tensor is None:
192      invalid_tensors.append(name)
193    else:
194      tensors.append(tensor)
195
196  # Throw ValueError if any user input names are not valid tensors.
197  if invalid_tensors:
198    raise ValueError("Invalid tensors '{}' were found.".format(
199        ",".join(invalid_tensors)))
200  return tensors
201
202
203def set_tensor_shapes(tensors, shapes):
204  """Sets Tensor shape for each tensor if the shape is defined.
205
206  Args:
    tensors: List of TensorFlow ops.Tensor objects.
    shapes: Dict mapping input tensor names (strings) to lists of integers
      representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
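
  Example (illustrative; `input_tensor` is a tensor named "foo:0" in the
  graph):
    set_tensor_shapes([input_tensor], {"foo": [1, 16, 16, 3]})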
215  """
216  if shapes:
217    tensor_names_to_tensor = {
218        get_tensor_name(tensor): tensor for tensor in tensors
219    }
220    for name, shape in shapes.items():
221      if name not in tensor_names_to_tensor:
222        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
223                         "map.".format(name))
224      if shape is not None:
225        tensor = tensor_names_to_tensor[name]
226        try:
227          tensor.set_shape(shape)
228        except ValueError as error:
229          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
230                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
231          raise ValueError(message)
232
233
234def get_grappler_config(optimizers_list):
235  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.
236
237  Args:
238    optimizers_list: List of strings that represents the list of optimizers.
239
240  Returns:
241    tf.ConfigProto.
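
  Example (illustrative; "function" is the optimizer used by `freeze_graph`
  below):
    config = get_grappler_config(["function"])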
242  """
243  config = _config_pb2.ConfigProto()
244  rewrite_options = config.graph_options.rewrite_options
245  for optimizer in optimizers_list:
246    rewrite_options.optimizers.append(optimizer)
247  return config
248
249
250def run_graph_optimizations(graph_def,
251                            input_arrays,
252                            output_arrays,
253                            config,
254                            graph=None):
255  """Apply standard TensorFlow optimizations to the graph_def.
256
257  Args:
258    graph_def: Frozen GraphDef to be optimized.
259    input_arrays: List of arrays that are considered inputs of the graph.
260    output_arrays: List of arrays that are considered outputs of the graph.
261    config: tf.ConfigProto.
262    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)
263
264  Returns:
265    A new, optimized GraphDef.
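
  Example (illustrative; mirrors the call in `freeze_graph` below):
    config = get_grappler_config(["function"])
    optimized_graph_def = run_graph_optimizations(
        sess.graph_def, input_tensors, output_tensors, config,
        graph=sess.graph)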
266  """
267  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)
268
269  signature = _meta_graph_pb2.SignatureDef()
270  for array in input_arrays:
271    signature.inputs[array.name].name = array.name
272    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
273    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())
274
275  for array in output_arrays:
276    signature.outputs[array.name].name = array.name
277    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
278    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())
279
280  meta_graph.signature_def["not_used_key"].CopyFrom(signature)
281
282  # We need to add a collection called 'train_op' so that grappler
283  # knows what the outputs are.
284  fetch_collection = _meta_graph_pb2.CollectionDef()
285  for array in input_arrays + output_arrays:
286    fetch_collection.node_list.value.append(array.name)
287  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
288
289  return tf_optimizer.OptimizeGraph(config, meta_graph)
290
291
292def _convert_op_hints_if_present(sess, graph_def, output_tensors,
293                                 hinted_outputs_nodes):
294  if is_frozen_graph(sess):
    raise ValueError("Op hint conversion requires an unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass to inline functions in the graph and then freezes the
  graph by converting Variables to constants. If the graph is already frozen,
  the existing GraphDef is returned as-is. If OpHints are present, the OpHint
  graph is converted instead.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
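
  Example (illustrative; `tf` refers to TensorFlow):
    with tf.compat.v1.Session() as sess:
      # ... build or load the model ...
      frozen_graph_def = freeze_graph(sess, [input_tensor], [output_tensor])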
318  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining simple functions, Grappler will also try to lower while
  # loops into their switch/merge representation, which is undesirable for
  # OpHints, so we simply remove those attributes to prevent Grappler from
  # doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If ophints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_node_names)
  else:
    return sess.graph_def


def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
        op.type).endswith("VariableOp"):
      return False
  return True


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the file indices remain unchanged.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the requested nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
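
  Example (illustrative; composes this helper with `build_debug_info_func`
  above):
    debug_info_func = build_debug_info_func(original_graph)
    debug_info = get_debug_info(debug_info_func, converted_graph_def)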
440  """
441  if not nodes_to_debug_info_func:
442    return None
443
444  # Collect all the debug info nodes from the converted_graph
445  original_nodes = set()
446  for node in converted_graph.node:
447    debug_nodes = node.experimental_debug_info.original_node_names
448    debug_funcs = node.experimental_debug_info.original_func_names
    # If `original_node_names` is empty, use the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
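
  Example (illustrative; `tflite_model` holds the serialized model bytes):
    source_text, header_text = convert_bytes_to_c_source(
        data=tflite_model, array_name="g_model",
        include_path="path/to/model_data.h")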
482  """
483
484  starting_pad = "   "
485  array_lines = []
486  array_line = starting_pad
487  for value in bytearray(data):
488    if (len(array_line) + 4) > max_line_width:
489      array_lines.append(array_line + "\n")
490      array_line = starting_pad
491    array_line += " 0x%02x," % (value)
492  if len(array_line) > len(starting_pad):
493    array_lines.append(array_line + "\n")
494  array_values = "".join(array_lines)
495
496  if include_guard is None:
497    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"
498
499  if include_path is not None:
500    include_line = "#include \"{include_path}\"\n".format(
501        include_path=include_path)
502  else:
503    include_line = ""
504
505  if use_tensorflow_license:
506    license_text = """
507/* Copyright {year} The TensorFlow Authors. All Rights Reserved.
508
509Licensed under the Apache License, Version 2.0 (the "License");
510you may not use this file except in compliance with the License.
511You may obtain a copy of the License at
512
513    http://www.apache.org/licenses/LICENSE-2.0
514
515Unless required by applicable law or agreed to in writing, software
516distributed under the License is distributed on an "AS IS" BASIS,
517WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
518See the License for the specific language governing permissions and
519limitations under the License.
520==============================================================================*/
521""".format(year=datetime.date.today().year)
522  else:
523    license_text = ""
524
525  source_template = """{license_text}
526// This is a TensorFlow Lite model file that has been converted into a C data
527// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
528// This form is useful for compiling into a binary for devices that don't have a
529// file system.
530
531{include_line}
532// We need to keep the data array aligned on some architectures.
533#ifdef __has_attribute
534#define HAVE_ATTRIBUTE(x) __has_attribute(x)
535#else
536#define HAVE_ATTRIBUTE(x) 0
537#endif
538#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
539#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
540#else
541#define DATA_ALIGN_ATTRIBUTE
542#endif
543
544const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
545{array_values}}};
546const int {array_name}_len = {array_length};
547"""
548
549  source_text = source_template.format(
550      array_name=array_name,
551      array_length=len(data),
552      array_values=array_values,
553      license_text=license_text,
554      include_line=include_line)
555
556  header_template = """
557{license_text}
558
559// This is a TensorFlow Lite model file that has been converted into a C data
560// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
561// This form is useful for compiling into a binary for devices that don't have a
562// file system.
563
564#ifndef {include_guard}
565#define {include_guard}
566
567extern const unsigned char {array_name}[];
568extern const int {array_name}_len;
569
570#endif  // {include_guard}
571"""
572
573  header_text = header_template.format(
574      array_name=array_name,
575      include_guard=include_guard,
576      license_text=license_text)
577
578  return source_text, header_text


def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
  # A fast check for whether "remove_tensors_idxs" (eg: [4,5,6]) is exactly the
  # trailing range of "tensors" indices (eg: [0,1,2,3,4,5,6]), in which case the
  # tensor list can simply be truncated.
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    # Delete the tensors
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""

  if inference_input_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not quantized.")
    # None of the inputs are float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type.")

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""

  if inference_output_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators
  dequant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      dequant_opcode_idxs.append(idx)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs are float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model output is dequantized
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output
    if op.opcodeIndex in dequant_opcode_idxs and \
        op.outputs[0] in subgraph.outputs:
      # If found, validate that the operator's output type is float
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type.")

  # Modify model output type
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model input is int16 quantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
    inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model output is int16
      dequantized, it must be in {tf.float32, tf.int16}, else it must be
      tf.float32)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
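
  Example (illustrative; `tflite_model` is the serialized output of a full
  integer quantization conversion):
    tflite_model = modify_model_io_type(
        tflite_model,
        inference_input_type=dtypes.int8,
        inference_output_type=dtypes.int8)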

  """
  if inference_input_type == dtypes.float32 and \
      inference_output_type == dtypes.float32:
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  if len(model_object.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model_object.subgraphs)))

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  return _convert_model_from_object_to_bytearray(model_object)
