# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import warnings

from six import PY3

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn  # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell  # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell  # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import tensor_name as _tensor_name
from tensorflow.lite.python.convert import toco_convert  # pylint: disable=unused-import
from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export


def _run_graph_optimizations(graph_def, input_arrays, output_arrays,
                             graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
  # Avoid remapping as it creates ops like _FusedConv2D, which are not
  # supported by TF Lite.
  rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
  return _tf_optimizer.OptimizeGraph(config, meta_graph)


@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Enum defining the optimizations to apply when generating tflite graphs.

  Some optimizations may come at the cost of accuracy.
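
  Example usage, as a sketch (`converter` stands in for a TFLiteConverter
  instance created elsewhere):

    ```python
    converter.optimizations = [Optimize.OPTIMIZE_FOR_SIZE]
    ```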
  """

  # Optimize for size.
  #
  # Optimizations that reduce the size of the model. These can include
  # quantizing the weights of the floating point model.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"

  # Optimize for latency.
  #
  # Optimizations that reduce the latency of the model. These can include
  # quantizing the weights of the floating point model.
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  def __str__(self):
    return self.value


@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
  """Representative dataset to evaluate optimizations.

  A representative dataset that can be used to evaluate optimizations by the
  converter. E.g. the converter can use these examples to estimate (min, max)
  ranges by calibrating the model on inputs. This can allow the converter to
  quantize a converted floating point model.
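
  Example usage, as a sketch (`calibration_samples` is a hypothetical
  collection of sample inputs, and `converter` stands in for a
  TFLiteConverter instance created elsewhere):

    ```python
    def representative_data_gen():
      # Yield a list of input arrays per invocation, each matching the
      # type and shape of the corresponding model input.
      for sample in calibration_samples:
        yield [sample]

    converter.representative_dataset = RepresentativeDataset(
        representative_data_gen)
    ```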
  """

  def __init__(self, input_gen, output_gen=None):
    """Creates a representative dataset.

    Args:
      input_gen: an input generator that can be used to generate input samples
        for the model. This must be a callable object that returns an object
        that supports the `iter()` protocol (e.g. a generator function). The
        elements generated must have the same type and shape as the inputs to
        the model.
      output_gen: (optional) an output generator that can be used to generate
        output samples for the model. This must be a callable object that
        returns an object that supports the `iter()` protocol (e.g. a generator
        function). The elements generated must have the same type and shape as
        the outputs of the model. (default None)
    """
    self.input_gen = input_gen
    self.output_gen = output_gen


@_tf_export("lite.TargetSpec")
class TargetSpec(object):
  """Specification of target device.

  Details about the target device. The converter optimizes the generated model
  for the specific device.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of OpsSet options
      supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
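
  Example usage, as a sketch (`converter` stands in for a TFLiteConverter
  instance; assumes this build exposes `OpsSet.SELECT_TF_OPS`):

    ```python
    # Allow ops from the TFLite builtin set, falling back to select
    # TensorFlow ops for anything unsupported.
    converter.target_spec.supported_ops = set(
        [OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS])
    ```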
  """

  def __init__(self, supported_ops=None):
    if supported_ops is None:
      supported_ops = set([OpsSet.TFLITE_BUILTINS])
    self.supported_ops = supported_ops


@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(object):
  """Converts a TensorFlow model into a TensorFlow Lite model.

  Attributes:
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    target_spec: Experimental flag, subject to change. Specification of target
      device.
    optimizations: Experimental flag, subject to change. A list of
      optimizations to apply when converting the model. The converter applies
      the optimizations by giving priority to the optimizations specified
      earlier in the list. E.g. `[Optimize.OPTIMIZE_FOR_SIZE,
      Optimize.OPTIMIZE_FOR_LATENCY]` requires the converter to do both size
      and latency optimizations, giving priority to size optimizations over
      latency optimizations.
    representative_dataset: A representative dataset that can be used to
      generate input and output samples for the model. The converter can use
      the dataset to evaluate different optimizations.

  Example usage:

    ```python
    # Converting a ConcreteFunction to a TensorFlow Lite model.
    converter = lite.TFLiteConverter.from_concrete_function(func)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    ```
  """

  def __init__(self, func):
    """Constructor for TFLiteConverter.

    Args:
      func: TensorFlow ConcreteFunction.
    """
    self._func = func
    self.allow_custom_ops = False
    self.target_spec = TargetSpec()
    self.representative_dataset = None
    self.optimizations = []

  @classmethod
  def from_concrete_function(cls, func):
    """Creates a TFLiteConverter class from a ConcreteFunction.

    Args:
      func: TensorFlow ConcreteFunction.

    Returns:
      TFLiteConverter class.
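
    Example usage, as a sketch (assumes `model` is a callable decorated with
    `tf.function`, and the input signature below is illustrative):

      ```python
      concrete_func = model.get_concrete_function(
          tf.TensorSpec(shape=[1, 16, 16, 3], dtype=tf.float32))
      converter = lite.TFLiteConverter.from_concrete_function(concrete_func)
      tflite_model = converter.convert()
      ```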
    """
    if not isinstance(func, _function.ConcreteFunction):
      message = "This function takes in a ConcreteFunction."
      if isinstance(func, _def_function.Function):
        message += (" To get the ConcreteFunction from a Function,"
                    " call get_concrete_function.")
      raise ValueError(message)
    return cls(func)

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
        self._func)
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs

    # Run a Grappler pass.
    graph_def = _run_graph_optimizations(frozen_func.graph.as_graph_def(),
                                         input_tensors, output_tensors,
                                         frozen_func.graph)

    # Checks dimensions in input tensors.
    for tensor in input_tensors:
      # Note that shape_list might be empty for scalar shapes.
      shape_list = tensor.shape.as_list()
      if None in shape_list[1:]:
        raise ValueError(
            "None is only supported in the 1st dimension. Tensor '{0}' has "
            "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list))
      elif shape_list and shape_list[0] is None:
        # Set the batch size to 1 if undefined.
        shape = tensor.shape.as_list()
        shape[0] = 1
        tensor.set_shape(shape)

    if self.representative_dataset:
      if not isinstance(self.representative_dataset, RepresentativeDataset):
        raise TypeError("`representative_dataset` must be an instance of "
                        "`RepresentativeDataset`")
      if self.representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for `representative_dataset`")

    # TODO(shashishekhar): For now the order of the optimizations is ignored.
    # Both size and latency optimizations decide whether to apply post
    # training optimizations.
    post_training_optimize = bool(
        len(
            set(self.optimizations)
            & set([Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE])))
    # Do weights-only quantization if there is no dataset for calibration.
    weights_only_quantize_flag = (
        post_training_optimize and (self.representative_dataset is None))

    converter_kwargs = {
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": weights_only_quantize_flag,
        "target_ops": self.target_spec.supported_ops,
    }

    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    if self.representative_dataset and post_training_optimize:
      calibrate_quantize = _calibrator.Calibrator(result)
      result = calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen)

    return result


@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.

  This is used to convert from a TensorFlow GraphDef or SavedModel into either
  a TFLite FlatBuffer or graph visualization.

  Attributes:
    inference_type: Target data type of real-number arrays in the output file.
      Must be `{tf.float32, tf.uint8}`. (default tf.float32)
    inference_input_type: Target data type of real-number input arrays. Allows
      for a different type for input arrays in the case of quantization.
      Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
    output_format: Output file format. Currently must be `{TFLITE,
      GRAPHVIZ_DOT}`. (default TFLITE)
    quantized_input_stats: Dict of strings representing input tensor names
      mapped to tuple of floats representing the mean and standard deviation
      of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
      `inference_input_type` is `QUANTIZED_UINT8`.
      real_input_value = (quantized_input_value - mean_value) / std_dev_value.
      (default {})
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver.
      (default False)
    post_training_quantize: Deprecated, please specify
      `[Optimize.OPTIMIZE_FOR_SIZE]` for `optimizations` instead. Boolean
      indicating whether to quantize the weights of the converted float model.
      Model size will be reduced and there will be latency improvements
      (at the cost of accuracy). (default False)
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet
      options indicating which converter to use.
      (default set([OpsSet.TFLITE_BUILTINS]))
    optimizations: Experimental flag, subject to change. A list of
      optimizations to apply when converting the model. The converter applies
      the optimizations by giving priority to the optimizations specified
      earlier in the list. E.g.
      `[Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY]` requires
      the converter to do both size and latency optimizations, giving priority
      to size optimizations over latency optimizations.
    representative_dataset: A representative dataset that can be used to
      generate input and output samples for the model. The converter can use
      the dataset to evaluate different optimizations.

  Example usage:

    ```python
    # Converting a GraphDef from session.
    converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a GraphDef from file.
    converter = lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a SavedModel.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()

    # Converting a tf.keras model file.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_model_file)
    tflite_model = converter.convert()
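
    # Converting a quantization-aware trained GraphDef to a fully quantized
    # model, as a sketch. The (mean, std_dev) stats below are illustrative:
    # real_input_value = (quantized_input_value - mean) / std_dev, so with
    # mean=0. and std_dev=255., uint8 [0, 255] maps to float [0.0, 1.0].
    # (Assumes QUANTIZED_UINT8 is exposed via lite.constants.)
    converter = lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
    converter.inference_type = lite.constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {input_arrays[0]: (0., 255.)}
    tflite_model = converter.convert()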
    ```
  """

  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: List of tuples pairing input tensor names with
        lists of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
        into TensorFlow and when `input_tensors` and `output_tensors` are
        None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)

    Raises:
      ValueError: Invalid arguments.
    """
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self.inference_type = constants.FLOAT
    self.inference_input_type = None
    self.output_format = constants.TFLITE
    self.quantized_input_stats = {}
    self.default_ranges_stats = None
    self.drop_control_dependency = True
    self.reorder_across_fake_quant = False
    self.change_concat_input_ranges = False
    self.allow_custom_ops = False
    self._post_training_quantize = False
    self.dump_graphviz_dir = None
    self.dump_graphviz_video = False
    self.target_ops = set([OpsSet.TFLITE_BUILTINS])
    self.representative_dataset = None
    self.optimizations = []

    # Attributes are used by models that cannot be loaded into TensorFlow.
    if not self._has_valid_tensors():
      if not input_arrays_with_shape or not output_arrays:
        raise ValueError(
            "If input_tensors and output_tensors are None, both "
            "input_arrays_with_shape and output_arrays must be defined.")
      self._input_arrays_with_shape = input_arrays_with_shape
      self._output_arrays = output_arrays
  @classmethod
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TFLiteConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      TFLiteConverter class.
    """
    graph_def = _freeze_graph(sess, output_tensors)
    return cls(graph_def, input_tensors, output_tensors)

  @classmethod
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when an input shape is None (e.g.,
        {"foo" : None}). (default None)

    Returns:
      TFLiteConverter class.

    Raises:
      IOError:
        File not found.
        Unable to parse input file.
      ValueError:
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required.
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        if not _file_io.file_exists(graph_def_file):
          raise IOError("File '{0}' does not exist.".format(graph_def_file))
        with _file_io.FileIO(graph_def_file, "rb") as f:
          file_content = f.read()

        try:
          graph_def = _graph_pb2.GraphDef()
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")

            if not isinstance(file_content, str):
              if PY3:
                file_content = file_content.decode("utf-8")
              else:
                file_content = file_content.encode("utf-8")
            graph_def = _graph_pb2.GraphDef()
            _text_format.Merge(file_content, graph_def)
          except (_text_format.ParseError, DecodeError):
            raise IOError(
                "Unable to parse input file '{}'.".format(graph_def_file))

        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False

        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")

          # Get input and output tensors.
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)

          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_arrays.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)

  @classmethod
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TFLiteConverter class from a SavedModel.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when an input shape is None (e.g.,
        {"foo" : None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default set("serve"))
      signature_key: Key identifying SignatureDef containing inputs and outputs.
        (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    Returns:
      TFLiteConverter class.
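
    Example usage, as a sketch (the tag set below is illustrative; the
    SavedModel must actually contain a MetaGraphDef with these tags):

      ```python
      converter = TFLiteConverter.from_saved_model(
          saved_model_dir, tag_set=set(["serve", "gpu"]))
      ```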
    """
    if tag_set is None:
      tag_set = set([_tag_constants.SERVING])
    if signature_key is None:
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                 output_arrays, tag_set, signature_key)
    return cls(
        graph_def=result[0], input_tensors=result[1], output_tensors=result[2])

  @classmethod
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None):
    """Creates a TFLiteConverter class from a tf.keras model file.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from the Keras model when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when an input shape is None (e.g.,
        {"foo" : None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from the Keras model when none are provided. (default None)

    Returns:
      TFLiteConverter class.
    """
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    keras_model = _keras.models.load_model(model_file)
    sess = _keras.backend.get_session()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, output_tensors)
    return cls(graph_def, input_tensors, output_tensors)

  def __setattr__(self, name, value):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]"
                    " instead." % name)
      if value:
        # Use OPTIMIZE_FOR_SIZE for post training for now.
        self.optimizations = [Optimize.OPTIMIZE_FOR_SIZE]
      else:
        self.optimizations = []
      return
    object.__setattr__(self, name, value)

  def __getattribute__(self, name):
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]"
                    " instead." % name)
      return Optimize.OPTIMIZE_FOR_SIZE in set(self.optimizations)
    return object.__getattribute__(self, name)

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on the value of `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    # Checks dimensions in input tensors.
    if self._has_valid_tensors():
      for tensor in self._input_tensors:
        shape = tensor.shape
        if not shape:
          raise ValueError("Provide an input shape for input array "
                           "'{0}'.".format(_tensor_name(tensor)))
        # Note that shape_list might be empty for scalar shapes.
        shape_list = shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          self._set_batch_size(batch_size=1)

    # Get quantization stats. Ensures there is one stat per name if the stats
    # are specified.
    if self.quantized_input_stats:
      quantized_stats = []
      invalid_stats = []
      for name in self.get_input_arrays():
        if name in self.quantized_input_stats:
          quantized_stats.append(self.quantized_input_stats[name])
        else:
          invalid_stats.append(name)

      if invalid_stats:
        raise ValueError("Quantization input stats are not available for input "
                         "tensors '{0}'.".format(",".join(invalid_stats)))
    else:
      quantized_stats = None

    if self.representative_dataset:
      if not isinstance(self.representative_dataset, RepresentativeDataset):
        raise TypeError("`representative_dataset` must be an instance of "
                        "`RepresentativeDataset`")
      if self.representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for `representative_dataset`")

    # TODO(shashishekhar): For now the order of the optimizations is ignored.
    # Both size and latency optimizations decide whether to apply post
    # training optimizations.
    post_training_optimize = bool(
        len(set(self.optimizations) & set([Optimize.OPTIMIZE_FOR_LATENCY,
                                           Optimize.OPTIMIZE_FOR_SIZE])))
    # Do weights-only quantization if there is no dataset for calibration.
    weights_only_quantize_flag = (
        post_training_optimize and (self.representative_dataset is None))

    converter_kwargs = {
        "inference_type": self.inference_type,
        "inference_input_type": self.inference_input_type,
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "output_format": self.output_format,
        "quantized_input_stats": quantized_stats,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": weights_only_quantize_flag,
        "target_ops": self.target_ops,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video
    }

    optimized_graph = None
    if self.inference_type == constants.QUANTIZED_UINT8:
      optimized_graph = self._graph_def
    else:
      # Run a Grappler pass; fall back to the original GraphDef if it fails.
      try:
        optimized_graph = _run_graph_optimizations(
            self._graph_def, self._input_tensors, self._output_tensors)
      except Exception:
        optimized_graph = self._graph_def

    # Converts model.
    if self._has_valid_tensors():
      result = _toco_convert_impl(
          input_data=optimized_graph,
          input_tensors=self._input_tensors,
          output_tensors=self._output_tensors,
          **converter_kwargs)
    else:
      result = _toco_convert_graph_def(
          input_data=optimized_graph,
          input_arrays_with_shape=self._input_arrays_with_shape,
          output_arrays=self._output_arrays,
          **converter_kwargs)

    if self.representative_dataset and post_training_optimize:
      calibrate_quantize = _calibrator.Calibrator(result)
      result = calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen)

    return result

  def get_input_arrays(self):
    """Returns a list of the names of the input tensors.

    Returns:
      List of strings.
    """
    if self._has_valid_tensors():
      return [_tensor_name(tensor) for tensor in self._input_tensors]
    else:
      return [name for name, _ in self._input_arrays_with_shape]

  def _has_valid_tensors(self):
    """Checks if the input and output tensors have been initialized.

    Returns:
      Bool.
    """
    return self._input_tensors and self._output_tensors

  def _set_batch_size(self, batch_size):
    """Sets the first dimension of the input tensor to `batch_size`.

    Args:
      batch_size: Batch size for the model. Replaces the first dimension of an
        input size array if undefined.

    Raises:
      ValueError: input_tensor is not defined.
    """
    if not self._has_valid_tensors():
      raise ValueError("The batch size cannot be set for this model. Please "
                       "use input_shapes parameter.")

    for tensor in self._input_tensors:
      shape = tensor.shape.as_list()
      shape[0] = batch_size
      tensor.set_shape(shape)


@_tf_export(v1=["lite.TocoConverter"])
class TocoConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.

  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
  """

  @classmethod
  @_deprecation.deprecated(None,
                           "Use `lite.TFLiteConverter.from_session` instead.")
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TocoConverter class from a TensorFlow Session."""
    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen graph."""
    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
                                             output_arrays, input_shapes)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TocoConverter class from a SavedModel."""
    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
                                            input_shapes, output_arrays,
                                            tag_set, signature_key)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None):
    """Creates a TocoConverter class from a tf.keras model file."""
    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
                                                 input_shapes, output_arrays)


def _is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not
  frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


def _freeze_graph(sess, output_tensors):
  """Returns a frozen GraphDef.

  Freezes a graph with Variables in it. Otherwise the existing GraphDef is
  returned.

  Args:
    sess: TensorFlow Session.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  if not _is_frozen_graph(sess):
    output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
    return _tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_arrays)
  else:
    return sess.graph_def