1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "absl/memory/memory.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
24
25 namespace tflite {
26 namespace optimize {
27 namespace utils {
28
namespace {
// Bounds of the symmetric int8 quantized range. Note that -128 is
// deliberately excluded so the representable range is symmetric about zero.
const int8_t kMinQuantizedValue = -127;
const int8_t kMaxQuantizedValue = 127;
}  // namespace
33
NumElements(const TensorT & tensor,uint64_t * num_elements)34 TfLiteStatus NumElements(const TensorT& tensor, uint64_t* num_elements) {
35 if (tensor.shape.empty()) {
36 return kTfLiteError;
37 }
38 *num_elements = 1;
39 for (const uint64_t dim : tensor.shape) {
40 *num_elements *= dim;
41 }
42 return kTfLiteOk;
43 }
44
45 // Nudge min and max so that floating point 0 falls exactly on a quantized
46 // value, returning the nudges scale and zero_point.
47 //
48 // Although this code originates from FakeQuantization in quantized training,
49 // we may deviate from that implementation as we please since we do not fine
50 // tune the weights with quantized training.
GetAsymmetricQuantizationParams(float min,float max,const int quant_min,const int quant_max,QuantizationParametersT * quantization_params)51 void GetAsymmetricQuantizationParams(
52 float min, float max, const int quant_min, const int quant_max,
53 QuantizationParametersT* quantization_params) {
54 const float quant_min_float = static_cast<float>(quant_min);
55 const float quant_max_float = static_cast<float>(quant_max);
56 // Adjust the boundaries to guarantee 0 is included.
57 min = std::min(static_cast<float>(min), 0.0f);
58 max = std::max(static_cast<float>(max), 0.0f);
59 const float scale = (max - min) / (quant_max_float - quant_min_float);
60 // Scale can be zero if min and max are exactly 0.0f.
61 float zero_point_from_min = quant_min_float;
62 if (scale != 0) {
63 zero_point_from_min = quant_min_float - min / scale;
64 }
65 int64_t zero_point;
66 if (zero_point_from_min < quant_min_float) {
67 zero_point = static_cast<int64_t>(quant_min);
68 } else if (zero_point_from_min > quant_max_float) {
69 zero_point = static_cast<int64_t>(quant_max);
70 } else {
71 zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
72 }
73 quantization_params->min = std::vector<float>(1, min);
74 quantization_params->max = std::vector<float>(1, max);
75 quantization_params->scale = std::vector<float>(1, scale);
76 quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
77 }
78
79 // Per-channel quantize a tensor at the given index and returns both scales and
80 // quantized values.
SymmetricPerChannelQuantization(const float * const input,const std::vector<int> & dimension,int32_t channel_dim_index,std::vector<float> * output_scales,std::vector<int8_t> * output_value)81 void SymmetricPerChannelQuantization(const float* const input,
82 const std::vector<int>& dimension,
83 int32_t channel_dim_index,
84 std::vector<float>* output_scales,
85 std::vector<int8_t>* output_value) {
86 const int32_t channel_dim_size = dimension[channel_dim_index];
87 std::vector<float> min_vals(channel_dim_size);
88 std::vector<float> max_vals(channel_dim_size);
89 std::vector<bool> has_min_max_value(channel_dim_size, false);
90 int indices[4];
91 RuntimeShape tensor_dims{dimension[0], dimension[1], dimension[2],
92 dimension[3]};
93
94 // Compute min max ranges per channel
95 for (indices[0] = 0; indices[0] < dimension[0]; indices[0]++) {
96 for (indices[1] = 0; indices[1] < dimension[1]; indices[1]++) {
97 for (indices[2] = 0; indices[2] < dimension[2]; indices[2]++) {
98 for (indices[3] = 0; indices[3] < dimension[3]; indices[3]++) {
99 int channel_idx = indices[channel_dim_index];
100 const float val = input[Offset(tensor_dims, indices)];
101 if (has_min_max_value[channel_idx]) {
102 if (min_vals[channel_idx] > val) {
103 min_vals[channel_idx] = val;
104 } else if (max_vals[channel_idx] < val) {
105 max_vals[channel_idx] = val;
106 }
107 } else {
108 min_vals[channel_idx] = val;
109 max_vals[channel_idx] = val;
110 has_min_max_value[channel_idx] = true;
111 }
112 }
113 }
114 }
115 }
116
117 // Calculate scales per channel
118 std::vector<float> scale_invs(channel_dim_size);
119 const float half_scale = kMaxQuantizedValue;
120 for (size_t channel_idx = 0; channel_idx < channel_dim_size; channel_idx++) {
121 const float half_range = std::max(std::abs(min_vals[channel_idx]),
122 std::abs(max_vals[channel_idx]));
123 output_scales->at(channel_idx) = half_range / half_scale;
124 if (half_range == 0) {
125 scale_invs[channel_idx] = 0;
126 } else {
127 scale_invs[channel_idx] = half_scale / half_range;
128 }
129 }
130
131 // Quantize the values.
132 SymmetricPerChannelQuantizeValues(input, scale_invs, dimension,
133 channel_dim_index, output_value);
134 }
135
SymmetricPerChannelQuantizeValues(const float * const input,const std::vector<float> & scales_inv,const std::vector<int> & dimension,int32_t channel_dim_index,std::vector<int8_t> * output_value)136 void SymmetricPerChannelQuantizeValues(const float* const input,
137 const std::vector<float>& scales_inv,
138 const std::vector<int>& dimension,
139 int32_t channel_dim_index,
140 std::vector<int8_t>* output_value) {
141 // Quantize the values.
142 int indices[4];
143 RuntimeShape tensor_dims{dimension[0], dimension[1], dimension[2],
144 dimension[3]};
145 for (indices[0] = 0; indices[0] < dimension[0]; indices[0]++) {
146 for (indices[1] = 0; indices[1] < dimension[1]; indices[1]++) {
147 for (indices[2] = 0; indices[2] < dimension[2]; indices[2]++) {
148 for (indices[3] = 0; indices[3] < dimension[3]; indices[3]++) {
149 int channel_idx = indices[channel_dim_index];
150 int index = Offset(tensor_dims, indices);
151 const float val = input[index];
152 const int32_t quantized_value =
153 static_cast<int32_t>(TfLiteRound(val * scales_inv[channel_idx]));
154 output_value->at(index) = std::min<int8_t>(
155 kMaxQuantizedValue,
156 std::max<int8_t>(kMinQuantizedValue, quantized_value));
157 }
158 }
159 }
160 }
161 }
162
SymmetricQuantizeTensor(ModelT * model,TensorT * tensor)163 TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
164 if (model == nullptr || tensor == nullptr) {
165 return kTfLiteError;
166 }
167
168 BufferT* buffer = model->buffers[tensor->buffer].get();
169 if (buffer == nullptr) {
170 return kTfLiteError;
171 }
172 float* float_data = reinterpret_cast<float*>(buffer->data.data());
173 uint64_t num_elements;
174 TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
175
176 std::vector<int8_t> quantized_buffer;
177 quantized_buffer.resize(num_elements);
178
179 float min_value, max_value, scaling_factor;
180 tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
181 quantized_buffer.data(), &min_value,
182 &max_value, &scaling_factor);
183
184 if (tensor->quantization == nullptr) {
185 tensor->quantization = absl::make_unique<QuantizationParametersT>();
186 }
187 tensor->quantization->scale = std::vector<float>(1, scaling_factor);
188 tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
189
190 uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
191 model->buffers[tensor->buffer]->data.assign(uint8_buffer,
192 uint8_buffer + num_elements);
193
194 // Update the tensor type.
195 tensor->type = TensorType_INT8;
196
197 return kTfLiteOk;
198 }
199
200 } // namespace utils
201 } // namespace optimize
202 } // namespace tflite
203