1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/toco_tooling.h"
16
17 #include <cstdlib>
18 #include <memory>
19 #include <set>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/lite/toco/allocate_transient_arrays.h"
24 #include "tensorflow/lite/toco/dump_graphviz.h"
25 #include "tensorflow/lite/toco/export_tensorflow.h"
26 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
27 #include "tensorflow/lite/toco/import_tensorflow.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/tflite/export.h"
30 #include "tensorflow/lite/toco/tflite/import.h"
31 #include "tensorflow/lite/toco/toco_flags.pb.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33 #include "tensorflow/core/platform/logging.h"
34
35 namespace toco {
36 namespace {
37 // CHECK-fails if the model contains a kUnsupported operation.
CheckUnsupportedOperations(const Model & model)38 void CheckUnsupportedOperations(const Model& model) {
39 std::set<string> unsupported_ops;
40 for (auto& op : model.operators) {
41 if (op->type == OperatorType::kUnsupported) {
42 unsupported_ops.insert(
43 static_cast<const TensorFlowUnsupportedOperator*>(op.get())
44 ->tensorflow_op);
45 }
46 }
47 QCHECK(unsupported_ops.empty())
48 << "These unsupported ops were not removed by graph transformations: "
49 << absl::StrJoin(unsupported_ops, ", ");
50 }
51
// Populates |transformations| with the general-purpose graph transformations
// that are run on every model, regardless of output format. The set must be
// empty on entry. NOTE(review): transformations appear to be registered in a
// deliberate order (e.g. attribute resolution before constant folding) —
// preserve the ordering when editing this list.
void MakeGeneralGraphTransformationsSet(
    GraphTransformationsSet* transformations) {
  CHECK(transformations->empty());
  // Rewrites of trivial ops into simpler equivalents (mostly Reshape).
  transformations->Add(new ConvertExpandDimsToReshape);
  transformations->Add(new ConvertSqueezeToReshape);
  transformations->Add(new ConvertTrivialAddNToAdd);
  transformations->Add(new ConvertTrivialPackToReshape);
  transformations->Add(new ConvertTrivialTileToConcat);
  transformations->Add(new ConvertTrivialTransposeToReshape);
  transformations->Add(new ConvertReorderAxes);
  transformations->Add(new ResolveReshapeAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  // Propagation of types/shapes/activations through the graph.
  transformations->Add(new PropagateActivationFunctionIntoConstants);
  transformations->Add(new PropagateArrayDataTypes);
  transformations->Add(new PropagateFixedSizes);
  // Removal of ops that have no effect on inference.
  transformations->Add(new RemoveTensorFlowAssert);
  transformations->Add(new RemoveTensorFlowIdentity);
  transformations->Add(new RemoveTrivialConcatenation);
  transformations->Add(new RemoveTrivialConcatenationInput);
  transformations->Add(new RemoveTrivialFakeQuant);
  transformations->Add(new RemoveTrivialSlice);
  transformations->Add(new RemoveUnusedOp);
  transformations->Add(new EnsureBiasVectors);
  transformations->Add(new ResolveReorderAxes);
  // MatMul lowering and fusions of binary ops into affine ops.
  transformations->Add(new UnrollBatchMatMul);
  transformations->Add(new ResolveTensorFlowMatMul);
  transformations->Add(new FuseBinaryIntoPrecedingAffine);
  transformations->Add(new FuseBinaryIntoFollowingAffine);
  transformations->Add(new FuseBroadcastIntoFollowingBinary);
  transformations->Add(new MergeReshapeIntoPrecedingTranspose);
  transformations->Add(new MoveBinaryOperatorBeforeReshape);
  transformations->Add(new ReorderElementwiseUnary);
  transformations->Add(new ReorderReshapeTranspose);
  transformations->Add(new ResolveBatchNormalization);
  // Constant folding of ops whose inputs are all constant.
  transformations->Add(new ResolveConstantBinaryOperator);
  transformations->Add(new ResolveConstantFill);
  transformations->Add(new ResolveConstantGather);
  transformations->Add(new ResolveConstantPack);
  transformations->Add(new ResolveConstantRandomUniform);
  transformations->Add(new ResolveConstantRange);
  transformations->Add(new ResolveConstantReshape);
  transformations->Add(new ResolveConstantSelect);
  transformations->Add(new ResolveConstantSlice);
  transformations->Add(new ResolveConstantStridedSlice);
  transformations->Add(new ResolveConstantTile);
  transformations->Add(new ResolveConstantTranspose);
  transformations->Add(new ResolveConstantUnaryOperator);
  // TensorFlow control-flow / graph-structure resolution.
  transformations->Add(new ResolveTensorFlowMerge);
  transformations->Add(new ResolveSqueezeAttributes);
  transformations->Add(new ResolveTensorFlowSwitch);
  transformations->Add(new ResolveTensorFlowConcat);
  transformations->Add(new ResolveMultiplyByZero);
  // Pattern identification: recognize op subgraphs as fused TFLite ops.
  transformations->Add(new IdentifyL2Normalization);
  transformations->Add(new IdentifyL2Pool);
  transformations->Add(new IdentifyRelu1);
  transformations->Add(new IdentifyPRelu);
  transformations->Add(new RemoveTrivialBinaryOperator);
  // FakeQuant handling and attribute resolution for the remaining ops.
  transformations->Add(new ResolveFakeQuantArgsFromVars);
  transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
  transformations->Add(new ResolveSpaceToBatchNDAttributes);
  transformations->Add(new ResolveBatchToSpaceNDAttributes);
  transformations->Add(new ResolvePadAttributes);
  transformations->Add(new ResolvePadV2Attributes);
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveReduceAttributes);
  transformations->Add(new ResolveConstantShapeOrRank);
  transformations->Add(new MakeInitialDequantizeOperator);
  transformations->Add(new UnpartitionEmbeddingLookup);
  transformations->Add(new ResolveGatherAttributes);
}
123
SupportsQuantization(FileFormat format)124 bool SupportsQuantization(FileFormat format) {
125 return (format == GRAPHVIZ_DOT || format == TFLITE);
126 }
127
SupportsFusedActivationFunction(FileFormat format)128 bool SupportsFusedActivationFunction(FileFormat format) {
129 return (format == GRAPHVIZ_DOT || format == TFLITE);
130 }
131
SupportsLstmCell(FileFormat format)132 bool SupportsLstmCell(FileFormat format) {
133 return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
134 format == TFLITE);
135 }
136
SupportsPreallocatedWorkspace(FileFormat format)137 bool SupportsPreallocatedWorkspace(FileFormat format) {
138 return (format == TFLITE);
139 }
140
SupportsShuffledFCWeights(FileFormat format)141 bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }
142
IsRealValued(toco::ArrayDataType type)143 bool IsRealValued(toco::ArrayDataType type) {
144 // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
145 // for quantized real-number values, and no other integer type is ever used
146 // for that. This is dirty, should be resolved as part of a more general push
147 // to more explicitly distinguish between true-integers and
148 // integers used as quantized values representing real numbers.
149 return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
150 type == toco::ArrayDataType::kUint8 ||
151 type == toco::ArrayDataType::kInt16);
152 }
153
// Records the final (post-transformation) data type on each real-valued input
// array of |model|, derived from |toco_flags|:
//  - non-quantizable output formats force float;
//  - otherwise --inference_input_type wins, falling back to --inference_type;
//  - if neither flag is set, input data types are left untouched.
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  ArrayDataType type;
  if (!SupportsQuantization(output_format)) {
    // Data type is implicitly float for non-quantized formats
    type = ArrayDataType::kFloat;
  } else if (toco_flags.has_inference_input_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
  } else if (toco_flags.has_inference_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
  } else {
    // Nothing to do. Data types stay as-is.
    return;
  }

  for (int i = 0; i < model->flags.input_arrays_size(); i++) {
    string const& array_name = model->flags.input_arrays(i).name();
    auto* array = &model->GetArray(array_name);
    // Note that the notion of changing data types only applies to real-numbers
    // arrays (see the documentation for inference_input_type).
    // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
    // i.e. represent real numbers by means of quantization parameters,
    // and not plain integer uint8 input arrays.
    if (!IsRealValued(array->data_type)) {
      // Ignore non-real data types.
      continue;
    }
    // The enum value QUANTIZED_UINT8 for --inference_type and
    // --inference_input_type has long meant just 'QUANTIZED', being used as
    // well in mixed 8-bit / 16-bit quantized models. However,
    // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit,
    // and people have run into issues in the situation where they have an
    // already mixed 8-bit / 16-bit quantized model in TFLITE format and
    // want to run it again through toco, without having to re-specify all the
    // extra array info that was used in the (complicated) process of initially
    // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
    // just work in that case, we implement the logic that when an array is
    // already quantized, if --inference_type is quantized (so we're not
    // asking to dequantize here), no change of quantized data type is to be
    // recorded.
    if (array->data_type != toco::ArrayDataType::kFloat &&
        type != toco::ArrayDataType::kFloat) {
      // Both the array and the requested type are quantized: keep the array's
      // existing quantized type (see the rationale above).
      continue;
    }

    array->final_data_type = type;
  }
}
202
203 } // namespace
204
Import(const TocoFlags & toco_flags,const ModelFlags & model_flags,const string & input_file_contents)205 std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
206 const ModelFlags& model_flags,
207 const string& input_file_contents) {
208 std::unique_ptr<Model> model;
209 switch (toco_flags.input_format()) {
210 case TENSORFLOW_GRAPHDEF: {
211 TensorFlowImportFlags tf_import_flags;
212 tf_import_flags.drop_control_dependency =
213 toco_flags.has_drop_control_dependency()
214 ? toco_flags.drop_control_dependency()
215 : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
216
217 tf_import_flags.import_all_ops_as_unsupported =
218 toco_flags.force_select_tf_ops();
219
220 model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
221 input_file_contents);
222 break;
223 }
224 case TFLITE:
225 model = toco::tflite::Import(model_flags, input_file_contents);
226 ResolveModelFlags(model_flags, model.get());
227 CheckInvariants(*model);
228 break;
229 default:
230 LOG(FATAL) << "Unhandled input_format='"
231 << FileFormat_Name(toco_flags.input_format()) << "'";
232 }
233
234 LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
235
236 return model;
237 }
238
TransformWithStatus(const TocoFlags & toco_flags,Model * model)239 tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
240 Model* model) {
241 const FileFormat output_format = toco_flags.output_format();
242 const IODataType inference_type = toco_flags.inference_type();
243
244 const bool quantize_output =
245 SupportsQuantization(output_format) &&
246 (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);
247
248 if (quantize_output) {
249 QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
250 << "Quantized inference is not allowed with float inputs.";
251 }
252
253 // Clean up after import.
254 SetFinalDataTypeOnInputs(toco_flags, model);
255 UseArraysExtraInfo(model, quantize_output);
256 FinishBuildingRNNStates(model);
257
258 // Remove unused ops before performing any other optimizations. This is to
259 // stop optimizations from crossing the input/output boundaries. For example
260 // this will stop BatchNorm fusing if the output node is in between a conv
261 // and BatchNorm layers.
262 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
263 model, "Removing unused ops", {new toco::RemoveUnusedOp}));
264
265 GraphTransformationsSet transformations;
266 MakeGeneralGraphTransformationsSet(&transformations);
267 auto* remove_trivial_reshape = new RemoveTrivialReshape;
268 transformations.Add(remove_trivial_reshape);
269 auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
270 if (quantize_output) {
271 resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
272 toco_flags.propagate_fake_quant_num_bits());
273 }
274 transformations.Add(resolve_constant_fake_quant);
275 if (SupportsFusedActivationFunction(output_format)) {
276 transformations.Add(new FuseActivationFunctions);
277 } else {
278 transformations.Add(new UnfuseActivationFunctions);
279 }
280 if (toco_flags.drop_fake_quant()) {
281 transformations.Add(new DropFakeQuant);
282 } else {
283 // See the doc for --reorder_across_fake_quant: that flag is needed to
284 // support some existing models, e.g. WordLens, that have FakeQuant
285 // nodes in the wrong places.
286 // TODO(benoitjacob): drop special casing when we can.
287 if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
288 transformations.Add(new DropFakeQuant);
289 }
290 }
291 transformations.Add(new ConvertPureConvToDepthwise);
292 if (SupportsLstmCell(output_format)) {
293 if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
294 transformations.Add(new IdentifyLstmCell);
295 }
296 if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
297 transformations.Add(new toco::SplitLstmCellInputs);
298 } else {
299 transformations.Add(new toco::MergeLstmCellInputs);
300 }
301 }
302 transformations.Add(new ResolveConstantConcatenation);
303 // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
304 // conv, so we need to make sure we don't convert to dilated depthwise conv
305 // when outputing to TF GraphDef.
306 auto* identify_dilated_conv = new IdentifyDilatedConv;
307 if (output_format == TENSORFLOW_GRAPHDEF) {
308 identify_dilated_conv->set_identify_depthwise_conv(false);
309 }
310 transformations.Add(identify_dilated_conv);
311 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
312 model, "general graph transformations", transformations));
313
314 if (quantize_output) {
315 if (toco_flags.propagate_fake_quant_num_bits()) {
316 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
317 model, "fake quant propagation graph transformations",
318 {new PropagateFakeQuantNumBits}));
319 }
320 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
321 model, "pre-quantization graph transformations",
322 {
323 new HardcodeMinMax,
324 new DropFakeQuant,
325 }));
326 }
327
328 // Try to merge bidirectional sequence lstm or rnn if present.
329 GraphTransformationsSet bidirectional_transformations;
330 bidirectional_transformations.Add(new RemoveUnusedOp);
331 bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
332 bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
333 bidirectional_transformations.Add(
334 new toco::GroupDynamicBidirectionalSequenceRnn);
335 bidirectional_transformations.Add(
336 new toco::GroupDynamicBidirectionalSequenceLstm);
337 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
338 model, "Group bidirectional sequence lstm/rnn",
339 bidirectional_transformations));
340
341 // Fix any issues with IO edges. This must happen after any transform that
342 // may modify the structure of the edges.
343 FixEdgeArrays(model);
344 FixOperatorOrdering(model);
345
346 if (quantize_output) {
347 // If the user specified default min/max ranges we need to set all arrays
348 // that didn't either have a min/max specified or get one set via
349 // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
350 // HardcodeMinMax to move changes through the graph as we make changes.
351 auto propagate_default_min_max =
352 absl::make_unique<PropagateDefaultMinMax>();
353 bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
354 toco_flags.has_default_ranges_max());
355 if (has_default_ranges_flag) {
356 propagate_default_min_max->DefineTypeRange(
357 ArrayDataType::kUint8, toco_flags.default_ranges_min(),
358 toco_flags.default_ranges_max());
359 }
360 if (toco_flags.has_default_int16_ranges_min() &&
361 toco_flags.has_default_int16_ranges_max()) {
362 propagate_default_min_max->DefineTypeRange(
363 ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
364 toco_flags.default_int16_ranges_max());
365 }
366 if (propagate_default_min_max->has_any_ranges_defined()) {
367 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
368 model, "default min-max range propagation graph transformations",
369 {
370 propagate_default_min_max.release(),
371 new HardcodeMinMax,
372 }));
373 }
374
375 CheckIsReadyForQuantization(*model);
376 auto* ensure_safe_for_int8_kernels =
377 new EnsureUint8WeightsSafeForFastInt8Kernels;
378 ensure_safe_for_int8_kernels->set_allow_nudging_weights(
379 toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
380 ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
381 has_default_ranges_flag);
382 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
383 model, "quantization graph transformations",
384 {
385 new RemoveTrivialQuantizedActivationFunc,
386 new RemoveTrivialQuantizedMinMax,
387 new Quantize,
388 new RemoveFinalDequantizeOp,
389 ensure_safe_for_int8_kernels,
390 }));
391 if (SupportsShuffledFCWeights(output_format)) {
392 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
393 model, "shuffling of FC weights", {new ShuffleFCWeights}));
394 }
395 } else {
396 GraphTransformationsSet dequantization_transformations{new Dequantize};
397 // Dequantize creates FakeQuant nodes. We may want to discard
398 // those immediately.
399 if (toco_flags.drop_fake_quant()) {
400 dequantization_transformations.Add(new DropFakeQuant);
401 }
402
403 TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
404 model, "dequantization graph transformations",
405 dequantization_transformations));
406 }
407
408 if (output_format == TENSORFLOW_GRAPHDEF) {
409 EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
410 }
411
412 // Deduplicate large constant arrays.
413 DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
414
415 LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
416
417 if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
418 // By now there shouldn't be any unsupported ops when exporting to
419 // TensorFlow GraphDef.
420 CheckUnsupportedOperations(*model);
421 }
422
423 if (SupportsPreallocatedWorkspace(output_format)) {
424 AllocateTransientArrays(model, kDefaultTransientDataAlignment);
425 LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
426 }
427
428 CheckModelCounts(*model);
429 CheckFinalDataTypesSatisfied(*model);
430
431 int64 ops_count;
432 if (EstimateArithmeticOpsCount(*model, &ops_count)) {
433 LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
434 << " billion (note that a multiply-add is counted as 2 ops).";
435 }
436 model->ops_count = ops_count;
437 return tensorflow::Status::OK();
438 }
439
Export(const TocoFlags & toco_flags,const Model & model,bool allow_custom_ops,string * output_file_contents)440 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
441 bool allow_custom_ops, string* output_file_contents) {
442 switch (toco_flags.output_format()) {
443 case TENSORFLOW_GRAPHDEF:
444 ExportTensorFlowGraphDef(model, output_file_contents);
445 break;
446 case TFLITE: {
447 toco::tflite::ExportParams params;
448
449 params.enable_select_tf_ops =
450 toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
451 params.allow_custom_ops = allow_custom_ops;
452 params.quantize_weights = toco_flags.post_training_quantize();
453
454 auto status = toco::tflite::Export(model, output_file_contents, params);
455 if (!status.ok()) {
456 LOG(ERROR) << status.error_message();
457 }
458 return status;
459 } break;
460 case GRAPHVIZ_DOT:
461 DumpGraphviz(model, output_file_contents, "Computation Graph");
462 break;
463 default:
464 LOG(FATAL) << "Unhandled output_format='"
465 << FileFormat_Name(toco_flags.output_format()) << "'";
466 }
467 return tensorflow::Status();
468 }
469
470 } // namespace toco
471