/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

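// A Conv operator may have a second output holding its im2col buffer. That
// buffer only contains rearranged copies of input values, so it shares the
// input array's MinMax.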
bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
  if (op->outputs.size() != 2) {
    return false;
  }
  auto& im2col_array = model->GetArray(op->outputs[1]);
  if (im2col_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!im2col_array.minmax);
  auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
  im2col_minmax.min = input_minmax.min;
  im2col_minmax.max = input_minmax.max;
  return true;
}

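// The output of L2Normalization has unit norm, so its values lie in [-1, 1].
// The hardcoded range collapses to [0, 1] (resp. [-1, 0]) when the input
// range shows that negative (resp. positive) values cannot occur.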
bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
  output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
  return true;
}

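// Deduces the input's MinMax from the output's, for ops (such as Relu) where
// the range after the op is known and the range before it is not.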
bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
  auto& input = model->GetArray(op->inputs[0]);
  if (input.minmax) {
    const auto* minmax = input.minmax.get();
    if (minmax) {
      return false;
    }
  }
  auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
    if (minmax) {
      input.GetOrCreateMinMax() = *minmax;
      return true;
    }
  }
  return false;
}

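// Unifies the MinMax of all inputs and the output of a Concatenation to
// their overall envelope, so that the op can be implemented as a pure
// byte-copy. Input ranges are only rewritten when the
// change_concat_input_ranges model flag is set.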
bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
  // Do not early return if the output already has min/max:
  // we may still need to adjust the inputs' min/max.
  bool has_minmax = false;
  double overall_min = std::numeric_limits<double>::infinity();
  double overall_max = -std::numeric_limits<double>::infinity();
  for (const auto& input : op->inputs) {
    if (model->GetArray(input).minmax) {
      has_minmax = true;
      const auto* minmax = model->GetArray(input).minmax.get();
      if (minmax) {
        overall_min = std::min(overall_min, minmax->min);
        overall_max = std::max(overall_max, minmax->max);
      }
    }
  }
  auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    has_minmax = true;
    const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
    if (minmax) {
      overall_min = std::min(overall_min, minmax->min);
      overall_max = std::max(overall_max, minmax->max);
    }
  }
  if (!has_minmax) {
    return false;
  }
  MinMax overall_minmax;
  overall_minmax.min = overall_min;
  overall_minmax.max = overall_max;
  bool changed = false;
  if (model->flags.change_concat_input_ranges()) {
    for (const auto& input : op->inputs) {
      auto& array = model->GetArray(input);
      if (!array.minmax) {
        changed = true;
      } else if (!(overall_minmax == array.GetMinMax())) {
        changed = true;
        LOG(WARNING)
            << "Tweaking the MinMax of array " << input << ", which is "
            << "an input to " << LogName(*op) << ", because we want all inputs "
            << "and outputs of a Concatenation operator to have the same "
            << "MinMax so that it can be implemented as a pure byte-copy, no "
               "arithmetic.";
      }
      array.GetOrCreateMinMax() = overall_minmax;
    }
  }
  if (!output.minmax) {
    changed = true;
  } else if (!(overall_minmax == output.GetMinMax())) {
    if (model->flags.change_concat_input_ranges()) {
      changed = true;
      LOG(WARNING)
          << "Tweaking the MinMax of the output array of " << LogName(*op)
          << ", because we want all inputs "
          << "and outputs of a Concatenation operator to have the same MinMax "
          << "so that it can be implemented as a pure byte-copy, no "
          << "arithmetic.";
    } else {
      return false;
    }
  }
  output.GetOrCreateMinMax() = overall_minmax;

  return changed;
}

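// Each output of a Split holds a subset of the input values, so all outputs
// inherit the input array's MinMax.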
bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
  // Data is in second input.
  auto& input_array = model->GetArray(op->inputs[1]);
  if (!input_array.minmax) {
    return false;
  }
  bool changed = false;
  for (const auto& output : op->outputs) {
    auto& array = model->GetArray(output);
    if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) {
      changed = true;
      array.GetOrCreateMinMax() = *input_array.minmax;
    }
  }
  return changed;
}

// The output of average or max pooling is within the same range as its input.
bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = std::min(input_minmax.min, 0.);
  output_minmax.max = std::max(input_minmax.max, 0.);
  return true;
}

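// Generic case: the output's MinMax equals the first input's MinMax, for ops
// that merely move or select input values (resize, slice, transpose,
// reductions that stay within the input range, etc.).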
bool HardcodeMinMaxFromFirstInput(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min;
  output_minmax.max = input_minmax.max;
  return true;
}

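// The output of Select comes from one of its two data inputs (inputs 1 and
// 2), which are therefore required to share the same MinMax; the output then
// inherits it. A missing MinMax on a constant data input is first filled in
// from the other input.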
bool HardcodeMinMaxForSelect(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }

  auto& input_array_1 = model->GetArray(op->inputs[1]);
  auto& input_array_2 = model->GetArray(op->inputs[2]);

  if (!input_array_1.minmax && !input_array_2.minmax) {
    return false;
  }

  // Propagate up if one input is quantized and the other is constant.
  if (!input_array_1.minmax &&
      IsConstantParameterArray(*model, op->inputs[1])) {
    auto& minmax_1 = input_array_1.GetOrCreateMinMax();
    const auto& minmax_2 = input_array_2.GetMinMax();
    minmax_1.min = minmax_2.min;
    minmax_1.max = minmax_2.max;
  }

  if (!input_array_2.minmax &&
      IsConstantParameterArray(*model, op->inputs[2])) {
    auto& minmax_2 = input_array_2.GetOrCreateMinMax();
    const auto& minmax_1 = input_array_1.GetMinMax();
    minmax_2.min = minmax_1.min;
    minmax_2.max = minmax_1.max;
  }

  if (!input_array_1.minmax || !input_array_2.minmax) {
    return false;
  }

  const auto& input_minmax_1 = input_array_1.GetMinMax();
  const auto& input_minmax_2 = input_array_2.GetMinMax();

  CHECK_EQ(input_minmax_1.min, input_minmax_2.min);
  CHECK_EQ(input_minmax_1.max, input_minmax_2.max);
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax_1.min;
  output_minmax.max = input_minmax_1.max;
  return true;
}

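// Hardcodes the output MinMax to the given fixed [min, max] interval, for
// ops whose output range is known a priori (Logistic, Softmax, Tanh).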
bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
                             double max) {
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = min;
  output_minmax.max = max;
  return true;
}

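// Returns whether two MinMax intervals agree to within a small tolerance,
// relative to the width of the narrower interval.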
bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
  const double magnitude =
      std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
  const double tolerated = 1e-6 * magnitude;
  return std::abs(minmax1.min - minmax2.min) < tolerated &&
         std::abs(minmax1.max - minmax2.max) < tolerated;
}

// Propagates MinMax from any of the listed arrays, to all others.
// If multiple of these arrays have MinMax, then these are required
// to agree with each other.
bool PropagateMinMaxAmongArrays(Model* model,
                                const std::vector<string> array_names) {
  string reference_array_name;
  MinMax* reference_minmax = nullptr;
  for (const string& array_name : array_names) {
    if (model->GetArray(array_name).minmax) {
      reference_array_name = array_name;
      reference_minmax = model->GetArray(array_name).minmax.get();
      break;
    }
  }
  // No MinMax info is available to propagate.
  if (!reference_minmax) {
    return false;
  }
  bool changed = false;
  for (const string& array_name : array_names) {
    auto& array = model->GetArray(array_name);
    if (array.minmax) {
      CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
          << "Both the following arrays have minmax, and they disagree: "
          << reference_array_name << " (" << reference_minmax->min << ","
          << reference_minmax->max << ") and " << array_name << " ("
          << array.minmax->min << "," << array.minmax->max
          << "). Expected that either only one of them would have minmax, or "
             "at least that they would agree.";
    } else {
      array.GetOrCreateMinMax() = *reference_minmax;
      changed = true;
    }
  }
  return changed;
}

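// Reshape only rearranges values, so its input and output arrays must share
// the same MinMax; whichever of the two has it propagates to the other.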
bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
  Array& input = model->GetArray(op->inputs[0]);
  Array& output = model->GetArray(op->outputs[0]);

  // If the input and output either both have MinMax or both lack it, there
  // is nothing to do.
  if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) {
    return false;
  }

  // Otherwise propagate the info between the input and output arrays.
  return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]});
}

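// Hardcodes the fixed ranges assumed by the quantized LstmCell runtime:
// [-1, 127/128] for the activation-like arrays, and [-8, 8 * 32767/32768]
// for the internal fully-connected output (see the comment above ACTIV_TEMP
// below). Also keeps the previous-state input and state output in sync.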
bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
  CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);

  bool changed = false;
  changed |= PropagateMinMaxAmongArrays(
      model, {op->inputs[LstmCellOperator::PREV_STATE_INPUT],
              op->outputs[LstmCellOperator::STATE_OUTPUT]});

  auto& input_activations =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  if (!input_activations.minmax) {
    auto& minmax = input_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& prev_output_activations =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  if (!prev_output_activations.minmax) {
    auto& minmax = prev_output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_concat_temp =
      model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]);
  if (!output_concat_temp.minmax) {
    auto& minmax = output_concat_temp.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_activations =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]);
  if (!output_activations.minmax) {
    auto& minmax = output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  // (This comment should morph into proper documentation for quantization of
  // LSTM models; it isn't just a local implementation detail, since the
  // training code for LSTM models needs to be adjusted to match.)
  //
  // Finally, output_activations_temp holds the output of the fully-connected
  // node inside the LSTM cell. For it, we hardcode a minmax of [-8, 8].
  // The rationale for that is given in a lengthy comment on the LstmCell
  // quantized runtime implementation in reference_ops.h.
  auto& output_activations_temp =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]);
  if (!output_activations_temp.minmax) {
    auto& minmax = output_activations_temp.GetOrCreateMinMax();
    minmax.min = -8;
    minmax.max = 8 * 32767. / 32768.;
    changed = true;
  }

  return changed;
}
}  // namespace

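// Dispatches on the operator type and hardcodes MinMax information wherever
// it can be deduced, reporting through *modified whether anything changed.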
::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
                                         bool* modified) {
  *modified = false;
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  bool changed = false;
  switch (op->type) {
    case OperatorType::kConv:
      changed = HardcodeMinMaxForIm2colArray(model, op);
      break;

    case OperatorType::kL2Normalization:
      changed = HardcodeMinMaxForL2Normalization(model, op);
      break;

    case OperatorType::kRelu:
      // For any normalization other than batch norm, the quantization ranges
      // before and after relu are expected to be known. Having a quantization
      // op before relu would cut the number of bits of precision for the
      // activation in half. So we deduce the range before relu from the range
      // after the relu. This eliminates the need for two fake quantization
      // nodes and does not reduce the bits of precision available for the
      // activation.
      changed = HardcodeInputMinMaxFromOutput(model, op);
      break;

    case OperatorType::kConcatenation:
      changed = HardcodeMinMaxForConcatenation(model, op);
      break;

    case OperatorType::kSplit:
      changed = HardcodeMinMaxForSplit(model, op);
      break;

    case OperatorType::kAveragePool:
    case OperatorType::kMaxPool:
      changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
      break;

    case OperatorType::kResizeBilinear:
    case OperatorType::kResizeNearestNeighbor:
    case OperatorType::kSlice:
    case OperatorType::kStridedSlice:
    case OperatorType::kSqueeze:
    case OperatorType::kExpandDims:
    case OperatorType::kPad:
    case OperatorType::kGather:
    case OperatorType::kTranspose:
    case OperatorType::kMean:
    case OperatorType::kReduceMax:
    case OperatorType::kReduceMin:
      changed = HardcodeMinMaxFromFirstInput(model, op);
      break;
    case OperatorType::kSum:
      // reduce_sum is expected to change the output range, so a fake_quant op
      // on the output is necessary to minimize error. However, in special
      // circumstances, such as computing an expected value with reduce_sum,
      // the input range and the output range match, so the code below acts as
      // a fallback. If a fake_quant node is observed on the output, it takes
      // precedence over this hardcoding logic.
      changed = HardcodeMinMaxFromFirstInput(model, op);
      if (changed) {
        LOG(WARNING) << "Using the input range for output in reduce_sum op. "
                     << "This could have an impact on your model accuracy.";
      }
      break;
    case OperatorType::kSelect:
      changed = HardcodeMinMaxForSelect(model, op);
      break;
    case OperatorType::kLogistic:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This minmax range is the one equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kSoftmax:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This minmax range is the one equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kTanh:
      // We hardcode quantization_params to: zero_point=127, scale=1/128.
      // This minmax range is the one equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, -127. / 128., 1.0);
      break;

    case OperatorType::kLstmCell:
      changed = HardcodeMinMaxForLstmCell(model, op);
      break;

    case OperatorType::kReshape:
      changed = HardcodeMinMaxForReshape(model, op);
      break;

    default:
      break;
  }
  if (changed) {
    AddMessageF("Hardcoded min-max through %s", LogName(*op));
  }
  *modified = changed;
  return ::tensorflow::Status::OK();
}

}  // namespace toco