1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/toco_tooling.h"
16 
17 #include <cstdlib>
18 #include <memory>
19 #include <set>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/lite/toco/allocate_transient_arrays.h"
24 #include "tensorflow/lite/toco/dump_graphviz.h"
25 #include "tensorflow/lite/toco/export_tensorflow.h"
26 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
27 #include "tensorflow/lite/toco/import_tensorflow.h"
28 #include "tensorflow/lite/toco/model_flags.pb.h"
29 #include "tensorflow/lite/toco/tflite/export.h"
30 #include "tensorflow/lite/toco/tflite/import.h"
31 #include "tensorflow/lite/toco/toco_flags.pb.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace toco {
36 namespace {
37 // CHECK-fails if the model contains a kUnsupported operation.
CheckUnsupportedOperations(const Model & model)38 void CheckUnsupportedOperations(const Model& model) {
39   std::set<string> unsupported_ops;
40   for (auto& op : model.operators) {
41     if (op->type == OperatorType::kUnsupported) {
42       unsupported_ops.insert(
43           static_cast<const TensorFlowUnsupportedOperator*>(op.get())
44               ->tensorflow_op);
45     }
46   }
47   QCHECK(unsupported_ops.empty())
48       << "These unsupported ops were not removed by graph transformations: "
49       << absl::StrJoin(unsupported_ops, ", ");
50 }
51 
// Populates |transformations| with the general-purpose set of graph
// transformations run on every model, regardless of output format.
// The set must be empty on entry; the GraphTransformationsSet takes
// ownership of the transformations added to it.
// NOTE(review): the registration order below is kept as-is — some
// transformations presumably depend on running after others (e.g.
// attribute-resolution passes before constant-folding passes); confirm
// before reordering.
void MakeGeneralGraphTransformationsSet(
    GraphTransformationsSet* transformations) {
  CHECK(transformations->empty());
  // Trivial-op rewrites: turn degenerate ops into simpler equivalents.
  transformations->Add(new ConvertExpandDimsToReshape);
  transformations->Add(new ConvertSqueezeToReshape);
  transformations->Add(new ConvertTrivialAddNToAdd);
  transformations->Add(new ConvertTrivialPackToReshape);
  transformations->Add(new ConvertTrivialTileToConcat);
  transformations->Add(new ConvertTrivialTransposeToReshape);
  transformations->Add(new ConvertReorderAxes);
  transformations->Add(new ResolveReshapeAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  // Propagation passes: spread data types, shapes and activation
  // functions through the graph.
  transformations->Add(new PropagateActivationFunctionIntoConstants);
  transformations->Add(new PropagateArrayDataTypes);
  transformations->Add(new PropagateFixedSizes);
  // Removal of no-op / trivial constructs.
  transformations->Add(new RemoveTensorFlowAssert);
  transformations->Add(new RemoveTensorFlowIdentity);
  transformations->Add(new RemoveTrivialConcatenation);
  transformations->Add(new RemoveTrivialConcatenationInput);
  transformations->Add(new RemoveTrivialFakeQuant);
  transformations->Add(new RemoveTrivialSlice);
  transformations->Add(new RemoveUnusedOp);
  transformations->Add(new EnsureBiasVectors);
  transformations->Add(new ResolveReorderAxes);
  // MatMul handling: unroll batched matmuls, then lower matmul.
  transformations->Add(new UnrollBatchMatMul);
  transformations->Add(new ResolveTensorFlowMatMul);
  // Fusion passes: fold binary ops into adjacent affine ops.
  transformations->Add(new FuseBinaryIntoPrecedingAffine);
  transformations->Add(new FuseBinaryIntoFollowingAffine);
  transformations->Add(new FuseBroadcastIntoFollowingBinary);
  transformations->Add(new MergeReshapeIntoPrecedingTranspose);
  transformations->Add(new MoveBinaryOperatorBeforeReshape);
  transformations->Add(new ReorderElementwiseUnary);
  transformations->Add(new ReorderReshapeTranspose);
  transformations->Add(new ResolveBatchNormalization);
  // Constant-folding passes: evaluate ops whose inputs are constant.
  transformations->Add(new ResolveConstantBinaryOperator);
  transformations->Add(new ResolveConstantFill);
  transformations->Add(new ResolveConstantGather);
  transformations->Add(new ResolveConstantPack);
  transformations->Add(new ResolveConstantRandomUniform);
  transformations->Add(new ResolveConstantRange);
  transformations->Add(new ResolveConstantReshape);
  transformations->Add(new ResolveConstantSelect);
  transformations->Add(new ResolveConstantSlice);
  transformations->Add(new ResolveConstantStridedSlice);
  transformations->Add(new ResolveConstantTile);
  transformations->Add(new ResolveConstantTranspose);
  transformations->Add(new ResolveConstantUnaryOperator);
  transformations->Add(new ResolveTensorFlowMerge);
  transformations->Add(new ResolveSqueezeAttributes);
  transformations->Add(new ResolveTensorFlowSwitch);
  transformations->Add(new ResolveTensorFlowConcat);
  transformations->Add(new ResolveMultiplyByZero);
  // Pattern-identification passes: recognize op clusters that map to
  // dedicated fused operators.
  transformations->Add(new IdentifyL2Normalization);
  transformations->Add(new IdentifyL2Pool);
  transformations->Add(new IdentifyRelu1);
  transformations->Add(new IdentifyPRelu);
  transformations->Add(new RemoveTrivialBinaryOperator);
  // FakeQuant handling and remaining attribute resolution.
  transformations->Add(new ResolveFakeQuantArgsFromVars);
  transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
  transformations->Add(new ResolveSpaceToBatchNDAttributes);
  transformations->Add(new ResolveBatchToSpaceNDAttributes);
  transformations->Add(new ResolvePadAttributes);
  transformations->Add(new ResolvePadV2Attributes);
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveReduceAttributes);
  transformations->Add(new ResolveConstantShapeOrRank);
  transformations->Add(new MakeInitialDequantizeOperator);
  transformations->Add(new UnpartitionEmbeddingLookup);
  transformations->Add(new ResolveGatherAttributes);
}
123 
SupportsQuantization(FileFormat format)124 bool SupportsQuantization(FileFormat format) {
125   return (format == GRAPHVIZ_DOT || format == TFLITE);
126 }
127 
SupportsFusedActivationFunction(FileFormat format)128 bool SupportsFusedActivationFunction(FileFormat format) {
129   return (format == GRAPHVIZ_DOT || format == TFLITE);
130 }
131 
SupportsLstmCell(FileFormat format)132 bool SupportsLstmCell(FileFormat format) {
133   return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
134           format == TFLITE);
135 }
136 
SupportsPreallocatedWorkspace(FileFormat format)137 bool SupportsPreallocatedWorkspace(FileFormat format) {
138   return (format == TFLITE);
139 }
140 
SupportsShuffledFCWeights(FileFormat format)141 bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }
142 
IsRealValued(toco::ArrayDataType type)143 bool IsRealValued(toco::ArrayDataType type) {
144   // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
145   // for quantized real-number values, and no other integer type is ever used
146   // for that. This is dirty, should be resolved as part of a more general push
147   // to more explicitly distinguish between true-integers and
148   // integers used as quantized values representing real numbers.
149   return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
150                            type == toco::ArrayDataType::kUint8 ||
151                            type == toco::ArrayDataType::kInt16);
152 }
153 
// Sets final_data_type on each real-valued input array of |model|,
// derived from the toco flags. Precedence: non-quantizable output
// formats force float; otherwise --inference_input_type wins over
// --inference_type; if neither is set, input data types are left as-is.
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  ArrayDataType type;
  if (!SupportsQuantization(output_format)) {
    // Data type is implicitly float for non-quantized formats
    type = ArrayDataType::kFloat;
  } else if (toco_flags.has_inference_input_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
  } else if (toco_flags.has_inference_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
  } else {
    // Nothing to do. Data types stay as-is.
    return;
  }

  for (int i = 0; i < model->flags.input_arrays_size(); i++) {
    string const& array_name = model->flags.input_arrays(i).name();
    auto* array = &model->GetArray(array_name);
    // Note that the notion of changing data types only applies to real-numbers
    // arrays (see the documentation for inference_input_type).
    // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
    // i.e. represent real numbers by means of quantization parameters,
    // and not plain integer uint8 input arrays.
    if (!IsRealValued(array->data_type)) {
      // Ignore non-real data types.
      continue;
    }
    // The enum value QUANTIZED_UINT8 for --inference_type and
    // --inference_input_type has long meant just 'QUANTIZED', being used as
    // well in mixed 8-bit / 16-bit quantized models. However,
    // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit,
    // and people have run into issues in the situation where they have an
    // already mixed 8-bit / 16-bit quantized model in TFLITE format and
    // want to run it again through toco, without having to re-specify all the
    // extra array info that was used in the (complicated) process of initially
    // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
    // just work in that case, we implement the logic that when an array is
    // already quantized, if  --inference_type is quantized (so we're not
    // asking to dequantize here), no change of quantized data type is to be
    // recorded.
    if (array->data_type != toco::ArrayDataType::kFloat &&
        type != toco::ArrayDataType::kFloat) {
      continue;
    }

    array->final_data_type = type;
  }
}
202 
203 }  // namespace
204 
// Imports a model from |input_file_contents| according to the input
// format in |toco_flags|, returning the owned in-memory Model.
// LOG(FATAL)s on unhandled input formats.
std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
                              const ModelFlags& model_flags,
                              const string& input_file_contents) {
  std::unique_ptr<Model> model;
  switch (toco_flags.input_format()) {
    case TENSORFLOW_GRAPHDEF: {
      TensorFlowImportFlags tf_import_flags;
      // Default to dropping control dependencies unless the output is
      // also a TF GraphDef (round-trip), or the flag is set explicitly.
      tf_import_flags.drop_control_dependency =
          toco_flags.has_drop_control_dependency()
              ? toco_flags.drop_control_dependency()
              : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);

      tf_import_flags.import_all_ops_as_unsupported =
          toco_flags.force_select_tf_ops();

      model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
                                       input_file_contents);
      break;
    }
    case TFLITE:
      model = toco::tflite::Import(model_flags, input_file_contents);
      // The TFLITE path resolves flags and checks invariants here; the
      // GraphDef path handles this inside ImportTensorFlowGraphDef.
      ResolveModelFlags(model_flags, model.get());
      CheckInvariants(*model);
      break;
    default:
      LOG(FATAL) << "Unhandled input_format='"
                 << FileFormat_Name(toco_flags.input_format()) << "'";
  }

  LogDump(kLogLevelModelChanged, "AT IMPORT", *model);

  return model;
}
238 
TransformWithStatus(const TocoFlags & toco_flags,Model * model)239 tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
240                                        Model* model) {
241   const FileFormat output_format = toco_flags.output_format();
242   const IODataType inference_type = toco_flags.inference_type();
243 
244   const bool quantize_output =
245       SupportsQuantization(output_format) &&
246       (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);
247 
248   if (quantize_output) {
249     QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
250         << "Quantized inference is not allowed with float inputs.";
251   }
252 
253   // Clean up after import.
254   SetFinalDataTypeOnInputs(toco_flags, model);
255   UseArraysExtraInfo(model, quantize_output);
256   FinishBuildingRNNStates(model);
257 
258   // Remove unused ops before performing any other optimizations. This is to
259   // stop optimizations from crossing the input/output boundaries. For example
260   // this will stop BatchNorm fusing if the output node is in between a conv
261   // and BatchNorm layers.
262   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
263       model, "Removing unused ops", {new toco::RemoveUnusedOp}));
264 
265   GraphTransformationsSet transformations;
266   MakeGeneralGraphTransformationsSet(&transformations);
267   auto* remove_trivial_reshape = new RemoveTrivialReshape;
268   transformations.Add(remove_trivial_reshape);
269   auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
270   if (quantize_output) {
271     resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
272         toco_flags.propagate_fake_quant_num_bits());
273   }
274   transformations.Add(resolve_constant_fake_quant);
275   if (SupportsFusedActivationFunction(output_format)) {
276     transformations.Add(new FuseActivationFunctions);
277   } else {
278     transformations.Add(new UnfuseActivationFunctions);
279   }
280   if (toco_flags.drop_fake_quant()) {
281     transformations.Add(new DropFakeQuant);
282   } else {
283     // See the doc for --reorder_across_fake_quant: that flag is needed to
284     // support some existing models, e.g. WordLens, that have FakeQuant
285     // nodes in the wrong places.
286     // TODO(benoitjacob): drop special casing when we can.
287     if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
288       transformations.Add(new DropFakeQuant);
289     }
290   }
291   transformations.Add(new ConvertPureConvToDepthwise);
292   if (SupportsLstmCell(output_format)) {
293     if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
294       transformations.Add(new IdentifyLstmCell);
295     }
296     if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
297       transformations.Add(new toco::SplitLstmCellInputs);
298     } else {
299       transformations.Add(new toco::MergeLstmCellInputs);
300     }
301   }
302   transformations.Add(new ResolveConstantConcatenation);
303   // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
304   // conv, so we need to make sure we don't convert to dilated depthwise conv
305   // when outputing to TF GraphDef.
306   auto* identify_dilated_conv = new IdentifyDilatedConv;
307   if (output_format == TENSORFLOW_GRAPHDEF) {
308     identify_dilated_conv->set_identify_depthwise_conv(false);
309   }
310   transformations.Add(identify_dilated_conv);
311   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
312       model, "general graph transformations", transformations));
313 
314   if (quantize_output) {
315     if (toco_flags.propagate_fake_quant_num_bits()) {
316       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
317           model, "fake quant propagation graph transformations",
318           {new PropagateFakeQuantNumBits}));
319     }
320     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
321         model, "pre-quantization graph transformations",
322         {
323             new HardcodeMinMax,
324             new DropFakeQuant,
325         }));
326   }
327 
328   // Try to merge bidirectional sequence lstm or rnn if present.
329   GraphTransformationsSet bidirectional_transformations;
330   bidirectional_transformations.Add(new RemoveUnusedOp);
331   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
332   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
333   bidirectional_transformations.Add(
334       new toco::GroupDynamicBidirectionalSequenceRnn);
335   bidirectional_transformations.Add(
336       new toco::GroupDynamicBidirectionalSequenceLstm);
337   TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
338       model, "Group bidirectional sequence lstm/rnn",
339       bidirectional_transformations));
340 
341   // Fix any issues with IO edges. This must happen after any transform that
342   // may modify the structure of the edges.
343   FixEdgeArrays(model);
344   FixOperatorOrdering(model);
345 
346   if (quantize_output) {
347     // If the user specified default min/max ranges we need to set all arrays
348     // that didn't either have a min/max specified or get one set via
349     // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
350     // HardcodeMinMax to move changes through the graph as we make changes.
351     auto propagate_default_min_max =
352         absl::make_unique<PropagateDefaultMinMax>();
353     bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
354                                     toco_flags.has_default_ranges_max());
355     if (has_default_ranges_flag) {
356       propagate_default_min_max->DefineTypeRange(
357           ArrayDataType::kUint8, toco_flags.default_ranges_min(),
358           toco_flags.default_ranges_max());
359     }
360     if (toco_flags.has_default_int16_ranges_min() &&
361         toco_flags.has_default_int16_ranges_max()) {
362       propagate_default_min_max->DefineTypeRange(
363           ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
364           toco_flags.default_int16_ranges_max());
365     }
366     if (propagate_default_min_max->has_any_ranges_defined()) {
367       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
368           model, "default min-max range propagation graph transformations",
369           {
370               propagate_default_min_max.release(),
371               new HardcodeMinMax,
372           }));
373     }
374 
375     CheckIsReadyForQuantization(*model);
376     auto* ensure_safe_for_int8_kernels =
377         new EnsureUint8WeightsSafeForFastInt8Kernels;
378     ensure_safe_for_int8_kernels->set_allow_nudging_weights(
379         toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
380     ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
381         has_default_ranges_flag);
382     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
383         model, "quantization graph transformations",
384         {
385             new RemoveTrivialQuantizedActivationFunc,
386             new RemoveTrivialQuantizedMinMax,
387             new Quantize,
388             new RemoveFinalDequantizeOp,
389             ensure_safe_for_int8_kernels,
390         }));
391     if (SupportsShuffledFCWeights(output_format)) {
392       TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
393           model, "shuffling of FC weights", {new ShuffleFCWeights}));
394     }
395   } else {
396     GraphTransformationsSet dequantization_transformations{new Dequantize};
397     // Dequantize creates FakeQuant nodes. We may want to discard
398     // those immediately.
399     if (toco_flags.drop_fake_quant()) {
400       dequantization_transformations.Add(new DropFakeQuant);
401     }
402 
403     TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
404         model, "dequantization graph transformations",
405         dequantization_transformations));
406   }
407 
408   if (output_format == TENSORFLOW_GRAPHDEF) {
409     EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
410   }
411 
412   // Deduplicate large constant arrays.
413   DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
414 
415   LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
416 
417   if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
418     // By now there shouldn't be any unsupported ops when exporting to
419     // TensorFlow GraphDef.
420     CheckUnsupportedOperations(*model);
421   }
422 
423   if (SupportsPreallocatedWorkspace(output_format)) {
424     AllocateTransientArrays(model, kDefaultTransientDataAlignment);
425     LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
426   }
427 
428   CheckModelCounts(*model);
429   CheckFinalDataTypesSatisfied(*model);
430 
431   int64 ops_count;
432   if (EstimateArithmeticOpsCount(*model, &ops_count)) {
433     LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
434               << " billion (note that a multiply-add is counted as 2 ops).";
435   }
436   model->ops_count = ops_count;
437   return tensorflow::Status::OK();
438 }
439 
Export(const TocoFlags & toco_flags,const Model & model,bool allow_custom_ops,string * output_file_contents)440 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
441                           bool allow_custom_ops, string* output_file_contents) {
442   switch (toco_flags.output_format()) {
443     case TENSORFLOW_GRAPHDEF:
444       ExportTensorFlowGraphDef(model, output_file_contents);
445       break;
446     case TFLITE: {
447       toco::tflite::ExportParams params;
448 
449       params.enable_select_tf_ops =
450           toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
451       params.allow_custom_ops = allow_custom_ops;
452       params.quantize_weights = toco_flags.post_training_quantize();
453 
454       auto status = toco::tflite::Export(model, output_file_contents, params);
455       if (!status.ok()) {
456         LOG(ERROR) << status.error_message();
457       }
458       return status;
459     } break;
460     case GRAPHVIZ_DOT:
461       DumpGraphviz(model, output_file_contents, "Computation Graph");
462       break;
463     default:
464       LOG(FATAL) << "Unhandled output_format='"
465                  << FileFormat_Name(toco_flags.output_format()) << "'";
466   }
467   return tensorflow::Status();
468 }
469 
470 }  // namespace toco
471