1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for post training quantization with calibration."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.util.lazy_loader import LazyLoader
21
# Lazy load the calibration wrapper module, since some of the performance
# benchmark Skylark rules break dependencies when it is imported eagerly.
# The string literals below must use double quotes to match a code-internal
# rewrite rule; do not reformat them.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "tensorflow_lite_wrap_calibration_wrapper")
29
30
class Calibrator(object):
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self, model_content):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.

    Raises:
      ValueError: If `model_content` is empty, or if the calibrator was unable
        to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    try:
      self._calibrator = (_calibration_wrapper.CalibrationWrapper
                          .CreateWrapperCPPFromBuffer(model_content))
    except Exception as e:
      # Any failure raised by the wrapper is reported as ValueError; the
      # original error text is embedded in the message so it is not lost.
      raise ValueError("Failed to parse the model: %s." % e)
    # Also guard against the wrapper returning a falsy object instead of
    # raising.
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  def calibrate_and_quantize(self, dataset_gen):
    """Calibrates the model with specified generator and then quantizes it.

    Args:
      dataset_gen: A zero-argument callable (for example a generator
        function) that returns an iterable of calibration samples. Each
        sample is passed unchanged to the wrapper's `FeedTensor`. Note that
        this is the callable itself, not an already-created generator.

    Returns:
      A quantized model, as returned by the wrapper's `QuantizeModel`.

    Raises:
      TypeError: If `dataset_gen` is not callable.
    """
    # Reject a bad argument before Prepare() is invoked on the wrapper, so a
    # caller passing e.g. a generator object fails without touching the
    # calibrator's state.
    if not callable(dataset_gen):
      raise TypeError(
          "`dataset_gen` must be a callable returning an iterable of "
          "calibration samples, got %s." % type(dataset_gen).__name__)
    self._calibrator.Prepare()
    for calibration_sample in dataset_gen():
      self._calibrator.FeedTensor(calibration_sample)
    return self._calibrator.QuantizeModel()
69