1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tooling_util.h"
16 
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
23 
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/str_split.h"
29 #include "re2/re2.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/dump_graphviz.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
35 
36 namespace toco {
37 
38 // Find the longest common prefix of two strings.
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)39 absl::string_view FindLongestCommonPrefix(absl::string_view a,
40                                           absl::string_view b) {
41   if (a.empty() || b.empty()) return absl::string_view();
42 
43   const char* pa = a.data();
44   const char* pb = b.data();
45   size_t count = 0;
46   const size_t limit = std::min(a.size(), b.size());
47   while (count < limit && *pa == *pb) {
48     ++pa;
49     ++pb;
50     ++count;
51   }
52 
53   return absl::string_view(a.data(), count);
54 }
55 
LogName(const Operator & op)56 string LogName(const Operator& op) {
57   const string& opname = HelpfulOperatorTypeName(op);
58   if (op.outputs.empty()) {
59     return toco::port::StringF("{%s operator}", opname);
60   } else {
61     return toco::port::StringF("{%s operator with output %s}", opname,
62                                op.outputs[0]);
63   }
64 }
65 
ArrayDataTypeName(ArrayDataType data_type)66 string ArrayDataTypeName(ArrayDataType data_type) {
67   switch (data_type) {
68     case ArrayDataType::kFloat:
69       return "float";
70     case ArrayDataType::kInt8:
71       return "int8";
72     case ArrayDataType::kUint8:
73       return "uint8";
74     case ArrayDataType::kInt16:
75       return "int16";
76     case ArrayDataType::kUint16:
77       return "uint16";
78     case ArrayDataType::kInt32:
79       return "int32";
80     case ArrayDataType::kUint32:
81       return "uint32";
82     case ArrayDataType::kInt64:
83       return "int64";
84     case ArrayDataType::kUint64:
85       return "uint64";
86     case ArrayDataType::kString:
87       return "string";
88     case ArrayDataType::kBool:
89       return "bool";
90     case ArrayDataType::kComplex64:
91       return "complex64";
92     case ArrayDataType::kNone:
93       return "None";
94     default:
95       LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
96   }
97 }
98 
IsInputArray(const Model & model,const string & array_name)99 bool IsInputArray(const Model& model, const string& array_name) {
100   for (const auto& input_array : model.flags.input_arrays()) {
101     if (array_name == input_array.name()) {
102       return true;
103     }
104   }
105   return false;
106 }
107 
IsOutputArray(const Model & model,const string & array_name)108 bool IsOutputArray(const Model& model, const string& array_name) {
109   for (const auto& output_array : model.flags.output_arrays()) {
110     if (array_name == output_array) {
111       return true;
112     }
113   }
114   return false;
115 }
116 
IsArrayConsumed(const Model & model,const string & name)117 bool IsArrayConsumed(const Model& model, const string& name) {
118   if (GetOpWithInput(model, name)) {
119     return true;
120   }
121   if (IsOutputArray(model, name)) {
122     return true;
123   }
124   for (const auto& rnn_state : model.flags.rnn_states()) {
125     if (rnn_state.back_edge_source_array() == name) {
126       return true;
127     }
128   }
129   return false;
130 }
131 
CountTrueOutputs(const Model & model,const Operator & op)132 int CountTrueOutputs(const Model& model, const Operator& op) {
133   int count = 0;
134   for (const string& output : op.outputs) {
135     if (IsArrayConsumed(model, output)) {
136       ++count;
137     }
138   }
139   return count;
140 }
141 
CountOpsWithInput(const Model & model,const string & array_name)142 int CountOpsWithInput(const Model& model, const string& array_name) {
143   int count = 0;
144   for (const auto& op : model.operators) {
145     for (auto& input : op->inputs) {
146       if (input == array_name) {
147         count++;
148         // Breaking here is important: some graphs have ops that use the
149         // same array as more than one of their inputs, and in that case
150         // we want it counted only once.
151         break;
152       }
153     }
154   }
155   return count;
156 }
157 
DeleteArrayIfUnused(const string & array_name,Model * model)158 bool DeleteArrayIfUnused(const string& array_name, Model* model) {
159   if (IsDiscardableArray(*model, array_name) &&
160       CountOpsWithInput(*model, array_name) == 0) {
161     model->EraseArray(array_name);
162     return true;
163   }
164   return false;
165 }
166 
DeleteArrayIfUsedOnce(const string & array_name,Model * model)167 bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
168   if (IsDiscardableArray(*model, array_name) &&
169       CountOpsWithInput(*model, array_name) == 1) {
170     model->EraseArray(array_name);
171     return true;
172   }
173   return false;
174 }
175 
DeleteOpAndArraysIfUnused(Model * model,const Operator * op)176 void DeleteOpAndArraysIfUnused(Model* model, const Operator* op) {
177   for (const string& array_name : op->inputs) {
178     DeleteArrayIfUsedOnce(array_name, model);
179   }
180   auto op_it = FindOp(*model, op);
181   CHECK(op_it != model->operators.end());
182   model->operators.erase(op_it);
183 }
184 
FindOpWithOutput(const Model & model,const string & array_name)185 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
186     const Model& model, const string& array_name) {
187   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
188     for (auto& output : it->get()->outputs) {
189       if (output == array_name) {
190         return it;
191       }
192     }
193   }
194   return model.operators.end();
195 }
196 
FindOpWithOutput(Model & model,const string & array_name)197 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
198     Model& model, const string& array_name) {
199   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
200     for (auto& output : it->get()->outputs) {
201       if (output == array_name) {
202         return it;
203       }
204     }
205   }
206   return model.operators.end();
207 }
208 
GetOpWithOutput(const Model & model,const string & array_name)209 Operator* GetOpWithOutput(const Model& model, const string& array_name) {
210   auto it = FindOpWithOutput(model, array_name);
211   return it == model.operators.end() ? nullptr : it->get();
212 }
213 
214 // GetFirstOpWithInput assumes that this finds the first op.
FindOpWithInput(const Model & model,const string & array_name)215 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
216     const Model& model, const string& array_name) {
217   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
218     for (auto& input : it->get()->inputs) {
219       if (input == array_name) {
220         return it;
221       }
222     }
223   }
224   return model.operators.end();
225 }
226 
FindOpWithInput(Model & model,const string & array_name)227 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
228     Model& model, const string& array_name) {
229   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
230     for (auto& input : it->get()->inputs) {
231       if (input == array_name) {
232         return it;
233       }
234     }
235   }
236   return model.operators.end();
237 }
238 
FindOp(const Model & model,const Operator * op)239 std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
240     const Model& model, const Operator* op) {
241   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
242     if (it->get() == op) {
243       return it;
244     }
245   }
246   return model.operators.end();
247 }
248 
FindOp(Model & model,const Operator * op)249 std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
250                                                         const Operator* op) {
251   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
252     if (it->get() == op) {
253       return it;
254     }
255   }
256   return model.operators.end();
257 }
258 
GetOpWithInput(const Model & model,const string & array_name)259 Operator* GetOpWithInput(const Model& model, const string& array_name) {
260   auto it = FindOpWithInput(model, array_name);
261   return it == model.operators.end() ? nullptr : it->get();
262 }
263 
GetFirstOpWithInput(const Model & model,const string & array_name)264 Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
265   auto it = FindOpWithInput(model, array_name);
266   return it == model.operators.end() ? nullptr : it->get();
267 }
268 
ReplaceArrayUsage(Model * model,const string & old_array_name,const string & new_array_name)269 void ReplaceArrayUsage(Model* model, const string& old_array_name,
270                        const string& new_array_name) {
271   for (auto& op_it : model->operators) {
272     Operator* op = op_it.get();
273     for (size_t i = 0; i < op->inputs.size(); ++i) {
274       if (op->inputs[i] == old_array_name) {
275         op->inputs[i] = new_array_name;
276       }
277     }
278     for (size_t i = 0; i < op->outputs.size(); ++i) {
279       if (op->outputs[i] == old_array_name) {
280         op->outputs[i] = new_array_name;
281       }
282     }
283   }
284 }
285 
FormatArraysList(const Model & model,const std::vector<string> & list)286 string FormatArraysList(const Model& model, const std::vector<string>& list) {
287   if (list.empty()) {
288     return "[]";
289   }
290   string result = "";
291   if (list.size() > 1) {
292     result += "[ ";
293   }
294   for (std::size_t i = 0; i < list.size(); i++) {
295     if (i > 0) {
296       result += ", ";
297     }
298     result += list[i];
299   }
300   if (list.size() > 1) {
301     result += " ]";
302   }
303   return result;
304 }
305 
OperatorTypeName(OperatorType type)306 const char* OperatorTypeName(OperatorType type) {
307   switch (type) {
308 #define HANDLE_OPERATORTYPENAME_CASE(c) \
309   case OperatorType::k##c:              \
310     return #c;
311     HANDLE_OPERATORTYPENAME_CASE(Abs)
312     HANDLE_OPERATORTYPENAME_CASE(Add)
313     HANDLE_OPERATORTYPENAME_CASE(AddN)
314     HANDLE_OPERATORTYPENAME_CASE(AveragePool)
315     HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
316     HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
317     HANDLE_OPERATORTYPENAME_CASE(Conv)
318     HANDLE_OPERATORTYPENAME_CASE(Concatenation)
319     HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
320     HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
321     HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
322     HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
323     HANDLE_OPERATORTYPENAME_CASE(Dequantize)
324     HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
325     HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
326     HANDLE_OPERATORTYPENAME_CASE(Log)
327     HANDLE_OPERATORTYPENAME_CASE(Logistic)
328     HANDLE_OPERATORTYPENAME_CASE(LstmCell)
329     HANDLE_OPERATORTYPENAME_CASE(MaxPool)
330     HANDLE_OPERATORTYPENAME_CASE(L2Pool)
331     HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
332     HANDLE_OPERATORTYPENAME_CASE(Mul)
333     HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
334     HANDLE_OPERATORTYPENAME_CASE(Elu)
335     HANDLE_OPERATORTYPENAME_CASE(Relu)
336     HANDLE_OPERATORTYPENAME_CASE(Relu1)
337     HANDLE_OPERATORTYPENAME_CASE(Relu6)
338     HANDLE_OPERATORTYPENAME_CASE(PRelu)
339     HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
340     HANDLE_OPERATORTYPENAME_CASE(Softmax)
341     HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
342     HANDLE_OPERATORTYPENAME_CASE(Div)
343     HANDLE_OPERATORTYPENAME_CASE(Tanh)
344     HANDLE_OPERATORTYPENAME_CASE(Sin)
345     HANDLE_OPERATORTYPENAME_CASE(All)
346     HANDLE_OPERATORTYPENAME_CASE(Assert)
347     HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
348     HANDLE_OPERATORTYPENAME_CASE(Fill)
349     HANDLE_OPERATORTYPENAME_CASE(FloorMod)
350     HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
351     HANDLE_OPERATORTYPENAME_CASE(Greater)
352     HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
353     HANDLE_OPERATORTYPENAME_CASE(Identity)
354     HANDLE_OPERATORTYPENAME_CASE(Less)
355     HANDLE_OPERATORTYPENAME_CASE(LessEqual)
356     HANDLE_OPERATORTYPENAME_CASE(MatMul)
357     HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
358     HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
359     HANDLE_OPERATORTYPENAME_CASE(Merge)
360     HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
361     HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
362     HANDLE_OPERATORTYPENAME_CASE(Neg)
363     HANDLE_OPERATORTYPENAME_CASE(OneHot)
364     HANDLE_OPERATORTYPENAME_CASE(Pack)
365     HANDLE_OPERATORTYPENAME_CASE(Pad)
366     HANDLE_OPERATORTYPENAME_CASE(PadV2)
367     HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
368     HANDLE_OPERATORTYPENAME_CASE(Range)
369     HANDLE_OPERATORTYPENAME_CASE(Rank)
370     HANDLE_OPERATORTYPENAME_CASE(Reshape)
371     HANDLE_OPERATORTYPENAME_CASE(Squeeze)
372     HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
373     HANDLE_OPERATORTYPENAME_CASE(Shape)
374     HANDLE_OPERATORTYPENAME_CASE(Slice)
375     HANDLE_OPERATORTYPENAME_CASE(Split)
376     HANDLE_OPERATORTYPENAME_CASE(SplitV)
377     HANDLE_OPERATORTYPENAME_CASE(Sqrt)
378     HANDLE_OPERATORTYPENAME_CASE(Square)
379     HANDLE_OPERATORTYPENAME_CASE(Switch)
380     HANDLE_OPERATORTYPENAME_CASE(Sub)
381     HANDLE_OPERATORTYPENAME_CASE(Sum)
382     HANDLE_OPERATORTYPENAME_CASE(Tile)
383     HANDLE_OPERATORTYPENAME_CASE(Transpose)
384     HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
385     HANDLE_OPERATORTYPENAME_CASE(Concat)
386     HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
387     HANDLE_OPERATORTYPENAME_CASE(Cast)
388     HANDLE_OPERATORTYPENAME_CASE(Floor)
389     HANDLE_OPERATORTYPENAME_CASE(Ceil)
390     HANDLE_OPERATORTYPENAME_CASE(Gather)
391     HANDLE_OPERATORTYPENAME_CASE(GatherNd)
392     HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
393     HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
394     HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
395     HANDLE_OPERATORTYPENAME_CASE(Mean)
396     HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
397     HANDLE_OPERATORTYPENAME_CASE(Svdf)
398     HANDLE_OPERATORTYPENAME_CASE(ArgMax)
399     HANDLE_OPERATORTYPENAME_CASE(ArgMin)
400     HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
401     HANDLE_OPERATORTYPENAME_CASE(Unsupported)
402     HANDLE_OPERATORTYPENAME_CASE(Exp)
403     HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
404     HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
405     HANDLE_OPERATORTYPENAME_CASE(Select)
406     HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
407     HANDLE_OPERATORTYPENAME_CASE(Equal)
408     HANDLE_OPERATORTYPENAME_CASE(NotEqual)
409     HANDLE_OPERATORTYPENAME_CASE(Pow)
410     HANDLE_OPERATORTYPENAME_CASE(Any)
411     HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
412     HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
413     HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
414     HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
415     HANDLE_OPERATORTYPENAME_CASE(Unpack)
416     HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
417     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
418     HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
419     HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
420     HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
421     HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
422     HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
423     HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
424     HANDLE_OPERATORTYPENAME_CASE(Unique)
425     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
426     HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
427     HANDLE_OPERATORTYPENAME_CASE(Cos)
428     HANDLE_OPERATORTYPENAME_CASE(Where)
429     HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
430     default:
431       LOG(FATAL) << "Unhandled op type";
432 #undef HANDLE_OPERATORTYPENAME_CASE
433   }
434 }
435 
HelpfulOperatorTypeName(const Operator & op)436 string HelpfulOperatorTypeName(const Operator& op) {
437   if (op.type == OperatorType::kUnsupported) {
438     return toco::port::StringF(
439         "(Unsupported TensorFlow op: %s)",
440         static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
441   }
442   return OperatorTypeName(op.type);
443 }
444 
OperatorSupportsFusedActivation(OperatorType type)445 bool OperatorSupportsFusedActivation(OperatorType type) {
446   switch (type) {
447     case OperatorType::kAdd:
448     case OperatorType::kAveragePool:
449     case OperatorType::kBatchNormalization:
450     case OperatorType::kConv:
451     case OperatorType::kDepthwiseConv:
452     case OperatorType::kDiv:
453     case OperatorType::kFullyConnected:
454     case OperatorType::kL2Pool:
455     case OperatorType::kMaxPool:
456     case OperatorType::kMul:
457     case OperatorType::kSub:
458     case OperatorType::kSquaredDifference:
459       return true;
460     default:
461       return false;
462   }
463 }
464 
LogSummary(int log_level,const Model & model)465 void LogSummary(int log_level, const Model& model) {
466   VLOG(log_level) << "Operators summary (" << model.operators.size()
467                   << " operators):";
468   std::unordered_multiset<OperatorType> ops_by_type;
469   for (const auto& op : model.operators) {
470     ops_by_type.insert(op->type);
471   }
472   auto it = ops_by_type.begin();
473   while (it != ops_by_type.end()) {
474     int count = ops_by_type.count(*it);
475     VLOG(log_level) << "    " << OperatorTypeName(*it) << ": " << count;
476     std::advance(it, count);
477   }
478 }
479 
LogArray(int log_level,const Model & model,const string & name)480 void LogArray(int log_level, const Model& model, const string& name) {
481   VLOG(log_level) << "Array: " << name;
482   if (!model.HasArray(name)) {
483     VLOG(log_level) << "  DOES NOT EXIST";
484     return;
485   }
486   const auto& array = model.GetArray(name);
487   VLOG(log_level) << "  Data type: " << ArrayDataTypeName(array.data_type);
488   VLOG(log_level) << "  Final type: "
489                   << ArrayDataTypeName(array.final_data_type);
490   if (array.buffer) {
491     VLOG(log_level) << "  Constant Buffer";
492   }
493   if (array.alloc) {
494     VLOG(log_level) << "  Transient Alloc";
495   }
496   if (array.has_shape()) {
497     const Shape& array_shape = array.shape();
498     if (array_shape.dimensions_count() == 0) {
499       VLOG(log_level) << "  (Zero dimensions)";
500     } else {
501       string message = "  Dims: ";
502       bool first = true;
503       for (const int dim : array_shape.dims()) {
504         if (!first) {
505           message += ", ";
506         }
507         first = false;
508         toco::port::AppendF(&message, "%d", dim);
509       }
510       VLOG(log_level) << message;
511     }
512   }
513   if (array.minmax) {
514     VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
515                     << array.minmax->max;
516   }
517   if (array.quantization_params) {
518     VLOG(log_level) << "  QuantizationParams: zero_point="
519                     << static_cast<int>(array.quantization_params->zero_point)
520                     << ", scale=" << array.quantization_params->scale;
521   }
522 }
523 
DumpGraphvizVideoFrame(const Model & model)524 void DumpGraphvizVideoFrame(const Model& model) {
525   namespace port = toco::port;
526 
527   const auto& dump_options = *GraphVizDumpOptions::singleton();
528   if (!dump_options.dump_graphviz_video) {
529     return;
530   }
531   CHECK(!dump_options.dump_graphviz.empty());
532   // TODO(benoitjacob): the static data here means that this function
533   // is stateful, not reentrant, and effectively leaks memory till exit
534   // (since dump_hashes can only grow in size). It also means that it
535   // really only is intended to be called for a single model during the
536   // process' lifetime. So it's not great design at all. The overriding
537   // design aspect here is to make the video-dumping code as unintrusive
538   // and self-contained as possible. Eventually, we'll want to have that
539   // cleaned-up, but that will require some form of general statefulness
540   // in toco (some kind of 'tooling state' data structure) that does
541   // not exist at present, and would be premature to design here just for
542   // this new video-dumping feature.
543   static int dump_id = 0;
544   static std::unordered_set<std::size_t> dump_hashes;
545   string graphviz_dump;
546   DumpGraphviz(model, &graphviz_dump,
547                toco::port::StringF("VIDEO frame:%05d", dump_id));
548   std::size_t hash = std::hash<string>{}(graphviz_dump);
549   if (!dump_hashes.count(hash)) {
550     LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
551     dump_hashes.insert(hash);
552     const auto result = port::file::SetContents(
553         port::file::JoinPath(
554             dump_options.dump_graphviz,
555             toco::port::StringF("toco_video_%05d.dot", dump_id)),
556         graphviz_dump, port::file::Defaults());
557     QCHECK(result.ok()) << result.error_message();
558     dump_id++;
559   }
560 }
561 
LogDump(int log_level,const string & message,const Model & model)562 void LogDump(int log_level, const string& message, const Model& model) {
563   namespace port = toco::port;
564   const auto& dump_options = *GraphVizDumpOptions::singleton();
565 
566   DumpGraphvizVideoFrame(model);
567   if (!dump_options.dump_graphviz.empty()) {
568     string graphviz_dump;
569 
570     DumpGraphviz(model, &graphviz_dump, message);
571     const auto result = port::file::SetContents(
572         port::file::JoinPath(
573             dump_options.dump_graphviz,
574             absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
575                          ".dot")),
576         graphviz_dump, port::file::Defaults());
577     QCHECK(result.ok()) << result.error_message();
578   }
579 
580   if (!VLOG_IS_ON(log_level)) {
581     return;
582   }
583   VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
584   LogSummary(log_level, model);
585   std::unordered_set<string> already_printed_arrays;
586   for (const auto& op : model.operators) {
587     for (const auto& input : op->inputs) {
588       if (!already_printed_arrays.count(input)) {
589         already_printed_arrays.insert(input);
590         LogArray(log_level, model, input);
591       }
592     }
593     VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
594     VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
595                     << FormatArraysList(model, op->outputs);
596     if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
597       VLOG(log_level) << "    (with fused activation function)";
598     }
599     for (const auto& output : op->outputs) {
600       if (!already_printed_arrays.count(output)) {
601         already_printed_arrays.insert(output);
602         LogArray(log_level, model, output);
603       }
604     }
605   }
606   VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
607 }
608 
609 // Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
ExtendShape(Shape * shape,int new_shape_size)610 void ExtendShape(Shape* shape, int new_shape_size) {
611   CHECK_GE(new_shape_size, shape->dimensions_count());
612   const int size_increase = new_shape_size - shape->dimensions_count();
613   auto* shape_dims = shape->mutable_dims();
614   shape_dims->insert(shape_dims->begin(), size_increase, 1);
615 }
616 
617 // TODO(b/62904716) Remove along with remaining uses.
UnextendShape(Shape * shape,int new_shape_size)618 void UnextendShape(Shape* shape, int new_shape_size) {
619   CHECK_LE(new_shape_size, shape->dimensions_count());
620   const int size_reduction = shape->dimensions_count() - new_shape_size;
621   for (int i = 0; i < size_reduction; i++) {
622     CHECK_EQ(shape->dims(i), 1);
623   }
624   std::vector<int>& shape_dims = *shape->mutable_dims();
625   shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
626 }
627 
// CHECK-fails unless every entry of `dims` is at least 1. The single exception
// is dims == [0]. That is the shape of a tensor whose data itself describes a
// scalar (rank 0) shape, i.e. a shape tensor with no elements.
// CheckNonEmptyShapeDimensions is stricter and allows no such exception. Use
// it for ops and comparisons where an empty shape doesn't make sense.
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  const bool is_scalar_shape_tensor = dims.size() == 1 && dims[0] == 0;
  if (!is_scalar_shape_tensor) {
    for (const auto& dim : dims) {
      CHECK_GE(dim, 1);
    }
  }
}
642 
CheckValidShape(const Shape & shape)643 void CheckValidShape(const Shape& shape) {
644   CheckValidShapeDimensions(shape.dims());
645 }
646 
IsNonEmpty(const Shape & shape)647 bool IsNonEmpty(const Shape& shape) {
648   for (int i = 0; i < shape.dimensions_count(); ++i) {
649     if (shape.dims(i) < 1) return false;
650   }
651   return true;
652 }
653 
CheckNonEmptyShapeDimensions(const Shape & shape)654 void CheckNonEmptyShapeDimensions(const Shape& shape) {
655   for (int i = 0; i < shape.dimensions_count(); ++i) {
656     CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
657                                  << ". shape = " << ShapeToString(shape);
658   }
659 }
660 
ShapesAgreeUpToBroadcasting(const Shape & shape0,const Shape & shape1)661 bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
662   CheckNonEmptyShapeDimensions(shape0);
663   CheckNonEmptyShapeDimensions(shape1);
664 
665   const Shape* longer = &shape0;
666   const Shape* shorter = &shape1;
667   if (shape1.dimensions_count() > shape0.dimensions_count()) {
668     longer = &shape1;
669     shorter = &shape0;
670   }
671 
672   // Walk dimensions back to front until we run out of dimensions in the shorter
673   // shape.
674   int longer_index = longer->dimensions_count() - 1;
675   int shorter_index = shorter->dimensions_count() - 1;
676   while (shorter_index >= 0) {
677     const int d_long = longer->dims(longer_index);
678     const int d_short = shorter->dims(shorter_index);
679     // Broadcasting fails if the dimensions are different *and* neither is 1.
680     if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
681       return false;
682     }
683     longer_index--;
684     shorter_index--;
685   }
686   return true;
687 }
688 
ShapesAgreeUpToExtending(const Shape & shape0,const Shape & shape1)689 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
690   CheckNonEmptyShapeDimensions(shape0);
691   CheckNonEmptyShapeDimensions(shape1);
692 
693   const Shape* longer = &shape0;
694   const Shape* shorter = &shape1;
695   if (shape1.dimensions_count() > shape0.dimensions_count()) {
696     longer = &shape1;
697     shorter = &shape0;
698   }
699 
700   // Walk dimensions back to front until we run out of dimensions in the shorter
701   // shape.
702   int longer_index = longer->dimensions_count() - 1;
703   int shorter_index = shorter->dimensions_count() - 1;
704   while (shorter_index >= 0) {
705     const int d_long = longer->dims(longer_index);
706     const int d_short = shorter->dims(shorter_index);
707     // Extending fails if the dimensions are different.
708     if (d_long != d_short) {
709       return false;
710     }
711     longer_index--;
712     shorter_index--;
713   }
714 
715   // The remaining dimensions in the longer shape must be 1.
716   while (longer_index >= 0) {
717     const int d_long = longer->dims(longer_index);
718     if (d_long != 1) {
719       return false;
720     }
721     longer_index--;
722   }
723 
724   return true;
725 }
726 
RequiredBufferSizeForShape(const Shape & shape)727 int RequiredBufferSizeForShape(const Shape& shape) {
728   CheckValidShape(shape);
729   int max_offset = 1;
730   for (const auto& dim : shape.dims()) {
731     max_offset *= dim;
732   }
733   return max_offset;
734 }
735 
IsConstantParameterArray(const Model & model,const string & name)736 bool IsConstantParameterArray(const Model& model, const string& name) {
737   if (!model.HasArray(name)) {
738     return false;
739   }
740 
741   return !!model.GetArray(name).buffer;
742 }
743 
744 namespace {
745 template <ArrayDataType A>
CompareArrayBuffers(const Array & lhs_array,const Array & rhs_array)746 bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
747   CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
748   CHECK(lhs_array.buffer) << "LHS must be constant";
749   CHECK(rhs_array.buffer) << "RHS must be constant";
750   const auto& lhs_data = lhs_array.GetBuffer<A>().data;
751   const auto& rhs_data = rhs_array.GetBuffer<A>().data;
752   CHECK_EQ(lhs_data.size(), rhs_data.size())
753       << "Buffer sizes must match in element count";
754   for (int i = 0; i < lhs_data.size(); ++i) {
755     if (lhs_data[i] != rhs_data[i]) {
756       return false;
757     }
758   }
759   return true;
760 }
761 
HaveSameMinMax(const Array & lhs_array,const Array & rhs_array)762 bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
763   if (lhs_array.minmax || rhs_array.minmax) {
764     if (!lhs_array.minmax || !rhs_array.minmax) {
765       return false;
766     }
767     if (!(*lhs_array.minmax == *rhs_array.minmax)) {
768       return false;
769     }
770   }
771   return true;
772 }
773 
HaveSameQuantizationParams(const Array & lhs_array,const Array & rhs_array)774 bool HaveSameQuantizationParams(const Array& lhs_array,
775                                 const Array& rhs_array) {
776   if (lhs_array.quantization_params || rhs_array.quantization_params) {
777     if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
778       return false;
779     }
780     if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
781       return false;
782     }
783   }
784   return true;
785 }
786 
787 }  // namespace
788 
CompareConstantArrays(const Array & lhs_array,const Array & rhs_array)789 bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
790   bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
791                      lhs_array.data_type == rhs_array.data_type &&
792                      lhs_array.final_data_type == rhs_array.final_data_type &&
793                      HaveSameMinMax(lhs_array, rhs_array) &&
794                      HaveSameQuantizationParams(lhs_array, rhs_array) &&
795                      lhs_array.narrow_range == rhs_array.narrow_range;
796   if (!attrs_equal) {
797     return false;
798   }
799   switch (lhs_array.data_type) {
800     case ArrayDataType::kBool:
801       return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
802     case ArrayDataType::kFloat:
803       return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
804     case ArrayDataType::kInt8:
805       return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
806     case ArrayDataType::kUint8:
807       return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
808     case ArrayDataType::kInt16:
809       return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
810     case ArrayDataType::kUint16:
811       return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
812     case ArrayDataType::kInt32:
813       return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
814     case ArrayDataType::kUint32:
815       return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
816     case ArrayDataType::kInt64:
817       return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
818     case ArrayDataType::kUint64:
819       return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
820     case ArrayDataType::kString:
821       return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
822     case ArrayDataType::kComplex64:
823       return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
824                                                             rhs_array);
825     default:
826       LOG(FATAL) << "Unsupported data type: "
827                  << ArrayDataTypeName(lhs_array.data_type);
828       return false;
829   }
830 }
831 
832 namespace {
// Turns an array name such as "name:3_5" into something acceptable as a
// TensorFlow node name ("name_3_5") by mapping every ':' to '_'. All other
// characters are kept as they are.
string SanitizeNameForTFNode(const string& array_name) {
  string sanitized;
  sanitized.reserve(array_name.size());
  for (const char ch : array_name) {
    sanitized.push_back(ch == ':' ? '_' : ch);
  }
  return sanitized;
}
840 
CheckInputArraysAreNotOutputArrays(const ModelFlags & model_flags)841 void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
842   for (const auto& input_array : model_flags.input_arrays()) {
843     for (const string& output_array : model_flags.output_arrays()) {
844       QCHECK_NE(input_array.name(), output_array)
845           << "The array " << output_array
846           << " is listed in both --input_arrays and --output_arrays.";
847     }
848   }
849 }
850 
IsAsciiPrintable(const string & name)851 bool IsAsciiPrintable(const string& name) {
852   for (char c : name) {
853     if (!absl::ascii_isprint(c)) {
854       return false;
855     }
856   }
857   return true;
858 }
859 
// Renders `name` as a two-column table: each character (when printable) next
// to its byte value in hex. Bytes that are not ASCII-printable are flagged.
// This makes invisible or non-ASCII characters in array names visible in
// error messages.
string DumpAscii(const string& name) {
  string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    // `char` may be signed. Passing it straight to a varargs call promotes it
    // to a (possibly negative) int, so bytes >= 0x80, i.e. exactly the
    // non-ASCII bytes this dump exists to reveal, would print sign-extended
    // (e.g. "ffffffc3" instead of "c3"). %x also expects an unsigned int.
    const unsigned int byte_value = static_cast<unsigned char>(c);
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c     | %x\n", c, byte_value);
    } else {
      port::AppendF(&result, "      | %x   Not ASCII printable!\n",
                    byte_value);
    }
  }
  return result;
}
873 
CheckNonAsciiIOArrays(const ModelFlags & model_flags)874 void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
875   if (model_flags.allow_nonascii_arrays()) {
876     return;
877   }
878   for (const auto& input_array : model_flags.input_arrays()) {
879     QCHECK(IsAsciiPrintable(input_array.name()))
880         << "Non-ASCII-printable character found in --input_arrays: "
881         << input_array.name()
882         << ". Pass --allow_nonascii_arrays to allow that. "
883         << "Here is a dump of the string:\n\n"
884         << DumpAscii(input_array.name());
885   }
886   for (const string& output_array : model_flags.output_arrays()) {
887     QCHECK(IsAsciiPrintable(output_array))
888         << "Non-ASCII-printable character found in --output_arrays: "
889         << output_array << ". Pass --allow_nonascii_arrays to allow that. "
890         << "Here is a dump of the string:\n\n"
891         << DumpAscii(output_array);
892   }
893 }
894 
CheckNonExistentIOArrays(const Model & model)895 void CheckNonExistentIOArrays(const Model& model) {
896   // "non-existent" is interpreted in the stronger sense of
897   // "not actually produced/consumed by an op".
898   // Rationale: we have to artificially fix up TensorFlow graphs by creating
899   // any array that it refers to, so just checking that arrays exist isn't
900   // sufficient. The real invariant here is whether arrays are produced/consumed
901   // by something.
902   if (model.flags.allow_nonexistent_arrays()) {
903     return;
904   }
905   static constexpr char general_comment[] =
906       "Is it a typo? To silence this message, pass this flag:  "
907       "allow_nonexistent_arrays";
908   for (const string& output_array : model.flags.output_arrays()) {
909     if (IsConstantParameterArray(model, output_array)) {
910       continue;  // It is OK to request that a constant be an output.
911     }
912     QCHECK(GetOpWithOutput(model, output_array))
913         << "Specified output array \"" << output_array
914         << "\" is not produced by any op in this graph. " << general_comment;
915   }
916   for (const auto& rnn_state : model.flags.rnn_states()) {
917     if (!rnn_state.discardable()) {
918       // Check that all RNN states are consumed
919       QCHECK(GetOpWithInput(model, rnn_state.state_array()))
920           << "Specified RNN state \"" << rnn_state.state_array()
921           << "\" is not consumed by any op in this graph. " << general_comment;
922       // Check that all RNN back-edge source arrays are produced
923       QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
924           << "Specified RNN back-edge source array \""
925           << rnn_state.back_edge_source_array()
926           << "\" is not produced by any op in this graph. " << general_comment;
927     }
928   }
929 }
930 
931 }  // namespace
932 
CheckNoMissingArray(const Model & model)933 void CheckNoMissingArray(const Model& model) {
934   for (const auto& op : model.operators) {
935     for (const auto& input : op->inputs) {
936       CHECK(model.HasArray(input) || model.optional_arrays.count(input))
937           << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
938     }
939     for (const auto& output : op->outputs) {
940       CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
941     }
942   }
943   CheckNonExistentIOArrays(model);
944 }
945 
FixNoMissingArray(Model * model)946 void FixNoMissingArray(Model* model) {
947   for (const auto& op : model->operators) {
948     for (const auto& input : op->inputs) {
949       if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
950         model->GetOrCreateArray(input);
951       }
952     }
953     for (const auto& output : op->outputs) {
954       if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
955         model->GetOrCreateArray(output);
956       }
957     }
958   }
959   if (model->flags.allow_nonexistent_arrays()) {
960     for (const string& output_array : model->flags.output_arrays()) {
961       model->GetOrCreateArray(output_array);
962     }
963     for (const auto& rnn_state : model->flags.rnn_states()) {
964       model->GetOrCreateArray(rnn_state.state_array());
965       model->GetOrCreateArray(rnn_state.back_edge_source_array());
966     }
967   }
968 }
969 
CheckNoOrphanedArray(const Model & model)970 void CheckNoOrphanedArray(const Model& model) {
971   std::unordered_set<string> arrays_without_known_use;
972   for (const auto& array : model.GetArrayMap()) {
973     if (IsDiscardableArray(model, array.first)) {
974       arrays_without_known_use.insert(array.first);
975     }
976   }
977   for (const auto& op : model.operators) {
978     for (const auto& input : op->inputs) {
979       arrays_without_known_use.erase(input);
980     }
981     for (const auto& output : op->outputs) {
982       arrays_without_known_use.erase(output);
983     }
984   }
985   for (const auto& rnn_state : model.flags.rnn_states()) {
986     arrays_without_known_use.erase(rnn_state.state_array());
987     arrays_without_known_use.erase(rnn_state.back_edge_source_array());
988   }
989   if (!arrays_without_known_use.empty()) {
990     for (const auto& array : arrays_without_known_use) {
991       LOG(INFO) << "Error: Orphaned array: " << array;
992     }
993   }
994   CHECK(arrays_without_known_use.empty());
995 }
996 
FixNoOrphanedArray(Model * model)997 void FixNoOrphanedArray(Model* model) {
998   std::unordered_set<string> arrays_without_known_use;
999   for (const auto& array : model->GetArrayMap()) {
1000     arrays_without_known_use.insert(array.first);
1001   }
1002   for (const auto& op : model->operators) {
1003     for (const auto& input : op->inputs) {
1004       arrays_without_known_use.erase(input);
1005     }
1006     for (const auto& output : op->outputs) {
1007       arrays_without_known_use.erase(output);
1008     }
1009   }
1010   for (const auto& rnn_state : model->flags.rnn_states()) {
1011     arrays_without_known_use.erase(rnn_state.state_array());
1012     arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1013   }
1014   for (const auto& array : arrays_without_known_use) {
1015     if (IsDiscardableArray(*model, array)) {
1016       model->EraseArray(array);
1017     }
1018   }
1019 }
1020 
1021 // Apply checks to arrays individually (for-each fashion).
1022 //
1023 // Check consistency of array fields, check name.
CheckEachArray(const Model & model)1024 void CheckEachArray(const Model& model) {
1025   for (const auto& array_entry : model.GetArrayMap()) {
1026     const auto& array = array_entry.second;
1027     // It's OK to have a buffer or an alloc, but not both.
1028     // (Since allocs are for transient arrays without a buffer).
1029     CHECK(!array->buffer || !array->alloc);
1030     if (array->buffer) {
1031       // If there is a buffer, its type should be consistent with data_type.
1032       CHECK(array->buffer->type == array->data_type);
1033       // The presence of a fixed buffer should imply the presence of a fixed
1034       // shape.
1035       CHECK(array->has_shape());
1036       // Constant buffer should has a valid shape.
1037       CheckValidShape(array->shape());
1038       // The shape flat-size should agree with the buffer length.
1039       CHECK_EQ(array->buffer->Length(),
1040                RequiredBufferSizeForShape(array->shape()));
1041     }
1042 
1043     // Check name.  Either "name_with_suffix_8", "name_with_port:3", but not
1044     // "name_with_both:3_8".
1045     const string& name = array_entry.first;
1046     auto colon_pos = name.find_first_of(":");
1047     if (colon_pos != string::npos) {
1048       CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
1049                string::npos)
1050           << "Array '" << name << "' has non-digit characters after colon.";
1051     }
1052     CHECK_GT(colon_pos, 0) << "Array '" << name
1053                            << "' must not start with a colon.";
1054   }
1055 }
1056 
CheckOperatorOrdering(const Model & model)1057 void CheckOperatorOrdering(const Model& model) {
1058   std::unordered_set<string> arrays_behind_us;
1059   for (const auto& array_entry : model.GetArrayMap()) {
1060     if (!GetOpWithOutput(model, array_entry.first)) {
1061       arrays_behind_us.insert(array_entry.first);
1062     }
1063   }
1064   arrays_behind_us.insert(model.optional_arrays.begin(),
1065                           model.optional_arrays.end());
1066   for (const auto& op : model.operators) {
1067     for (const auto& input : op->inputs) {
1068       if (!IsConstantParameterArray(model, input)) {
1069         CHECK(arrays_behind_us.count(input));
1070       }
1071     }
1072     for (const auto& output : op->outputs) {
1073       CHECK(!arrays_behind_us.count(output));
1074       arrays_behind_us.insert(output);
1075     }
1076   }
1077   for (const string& output_array : model.flags.output_arrays()) {
1078     CHECK(arrays_behind_us.count(output_array));
1079   }
1080 }
1081 
// Reorders model->operators so that each operator comes after the operators
// that produce its non-constant inputs.
//
// Arrays for which GetOpWithOutput finds no producer, plus the model's
// optional arrays, count as available from the start. Each pass scans the
// not-yet-placed operators in their original order. It moves the first one
// whose non-constant inputs are all available to the end of
// model->operators, marks its outputs available, and starts a new pass. This
// keeps the original relative order wherever dependencies allow it. Because
// each pass places at most one operator, the number of passes is linear in
// the number of operators, and total work is quadratic in the worst case.
//
// If some operators can never be placed, the function logs a trace. The trace
// starts at the lowest-index unplaced operator and walks back through the
// producers of its missing inputs. It ends in LOG(FATAL), reporting either a
// cycle or an input array that nothing provides.
void FixOperatorOrdering(Model* model) {
  // Arrays known to be available at the current point of the new ordering.
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  // Take the current operators out of the model; they are moved back into
  // model->operators one at a time, in dependency order.
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  // Indices into old_operators of the operators not placed yet.
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // For each output of an operator that could not be placed: the first of
  // its inputs found unavailable the last time the operator was examined.
  // Only read by the diagnostic trace below.
  std::unordered_map<string, string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        // Move the operator into the new last slot, leaving nullptr behind
        // in old_operators.
        std::swap(op, model->operators.back());
        // Erasing invalidates the range-for over `remaining`, hence the
        // immediate break to start a fresh pass.
        // NOTE(review): `i` refers to the very element being erased;
        // erase-by-key is expected to cope, but copying the index first would
        // make that obviously safe.
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    // A full pass without progress: every remaining operator is blocked.
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    // Missing inputs visited so far; seeing one twice means a cycle.
    std::unordered_set<string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      // bad_op is always an unplaced operator. The final, unsuccessful pass
      // examined all of those, so its outputs should have a recorded reason.
      bool found_bad_output = false;
      string bad_output;
      for (const string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      // Step back: find the unplaced operator that produces bad_input.
      bad_op = nullptr;
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      // No unplaced producer. bad_input is also not available (otherwise it
      // would not have been recorded), so nothing provides it.
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
1196 
CheckInvariants(const Model & model)1197 void CheckInvariants(const Model& model) {
1198   CheckInputArraysAreNotOutputArrays(model.flags);
1199   CheckNonAsciiIOArrays(model.flags);
1200   CheckNoMissingArray(model);
1201   CheckNoOrphanedArray(model);
1202   CheckEachArray(model);
1203   CheckOperatorOrdering(model);
1204 }
1205 
CheckCountInRange(const::toco::ModelFlags::ModelCheck & model_check,const int count,const string & count_description)1206 void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
1207                        const int count, const string& count_description) {
1208   if (model_check.count_min() >= 0) {
1209     CHECK_GE(count, model_check.count_min())
1210         << "Mismatch in " << count_description << ": count  was " << count
1211         << ", but the specified "
1212         << (model_check.count_max() > model_check.count_min() ? "minimum"
1213                                                               : "value")
1214         << " was " << model_check.count_min() << ".";
1215   }
1216   if (model_check.count_max() > model_check.count_min()) {
1217     CHECK_LE(count, model_check.count_max())
1218         << "Mismatch in " << count_description << ": count  was " << count
1219         << ", but the specified maximum was " << model_check.count_max() << ".";
1220   }
1221 }
1222 
CheckModelCounts(const Model & model)1223 void CheckModelCounts(const Model& model) {
1224   std::unordered_multiset<OperatorType> ops_by_type;
1225   std::unordered_map<string, OperatorType> op_type_by_name;
1226   if (model.flags.model_checks_size() == 0) {
1227     return;
1228   }
1229 
1230   for (const auto& op : model.operators) {
1231     ops_by_type.insert(op->type);
1232     op_type_by_name[OperatorTypeName(op->type)] = op->type;
1233   }
1234   for (const auto& model_check : model.flags.model_checks()) {
1235     string count_type = model_check.count_type();
1236     if (count_type == "None") {
1237       continue;
1238     } else if (count_type == "Arrays") {
1239       CheckCountInRange(model_check, model.GetArrayMap().size(),
1240                         "count of arrays");
1241     } else if (count_type == "Total") {
1242       CheckCountInRange(model_check, model.operators.size(),
1243                         "count of all operator instances");
1244     } else {
1245       // The check type is not itself checked against the set of valid
1246       // operators, mainly because the enum set cannot be iterated in C++.
1247       const int found_count =
1248           op_type_by_name.count(count_type) > 0
1249               ? ops_by_type.count(op_type_by_name[count_type])
1250               : 0;
1251       CheckCountInRange(model_check, found_count,
1252                         "count of instances of " + count_type + " operator");
1253     }
1254   }
1255 }
1256 
FixEdgeArrays(Model * model)1257 void FixEdgeArrays(Model* model) {
1258   for (const string& output_array_name : model->flags.output_arrays()) {
1259     if (!GetOpWithOutput(*model, output_array_name)) {
1260       // Output has no operator producing it. Change that by inserting a copy.
1261       LOG(WARNING) << "Fixing constant output array " << output_array_name
1262                    << " by inserting a copy. This is not optimal.";
1263       string intermediate_array_name =
1264           AvailableArrayName(*model, output_array_name + "_copy");
1265       CloneArray(model, output_array_name, intermediate_array_name);
1266       InsertCopyOperator(model, intermediate_array_name, output_array_name);
1267     }
1268   }
1269 }
1270 
DedupeConstantArrays(Model * model,size_t min_size)1271 void DedupeConstantArrays(Model* model, size_t min_size) {
1272   // Walk all 0..N and compare with the remaining n+1..N.
1273   // This lets us avoid N^2 comparisons and erase duplicate arrays while
1274   // iterating.
1275   const auto& array_map = model->GetArrayMap();
1276   for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
1277        ++lhs_array_it) {
1278     const auto& lhs_array_name = lhs_array_it->first;
1279     const auto& lhs_array = *lhs_array_it->second;
1280     if (!IsConstantParameterArray(*model, lhs_array_name)) {
1281       // Not a constant array; skip.
1282       continue;
1283     }
1284     ArrayDataType final_data_type =
1285         lhs_array.final_data_type != ArrayDataType::kNone
1286             ? lhs_array.final_data_type
1287             : lhs_array.data_type;
1288     // Ignore small arrays, don't check string arrays because it is not possible
1289     // to estimate its size.
1290     if (final_data_type != ArrayDataType::kString) {
1291       size_t array_byte_size =
1292           lhs_array.buffer->Length() * ElementSize(final_data_type);
1293       if (array_byte_size < min_size) {
1294         // Too small; skip.
1295         continue;
1296       }
1297     }
1298 
1299     auto next_lhs_array_it = lhs_array_it;
1300     ++next_lhs_array_it;
1301     for (auto rhs_array_it = next_lhs_array_it;
1302          rhs_array_it != array_map.end();) {
1303       const auto& rhs_array_name = rhs_array_it->first;
1304       const auto& rhs_array = *rhs_array_it->second;
1305       ++rhs_array_it;
1306       if (!IsConstantParameterArray(*model, rhs_array_name)) {
1307         // Not a constant array; skip.
1308         continue;
1309       }
1310       if (!IsDiscardableArray(*model, rhs_array_name)) {
1311         // Can't remove the array as it's not discardable (such as an IO edge).
1312         continue;
1313       }
1314       if (!CompareConstantArrays(lhs_array, rhs_array)) {
1315         // Arrays aren't equal; skip.
1316         continue;
1317       }
1318 
1319       // Arrays can be deduped!
1320       VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
1321               << " in place of " << rhs_array_name;
1322       ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
1323       // Note: rhs_array_it above is already incremented so this is safe.
1324       model->EraseArray(rhs_array_name);
1325     }
1326   }
1327 }
1328 
1329 namespace {
CopyArrayAttribs(const Array & source_array,Array * target_array)1330 void CopyArrayAttribs(const Array& source_array, Array* target_array) {
1331   target_array->data_type = source_array.data_type;
1332   target_array->final_data_type = source_array.final_data_type;
1333   target_array->copy_shape(source_array.shape());
1334 
1335   if (source_array.minmax) {
1336     target_array->GetOrCreateMinMax() = source_array.GetMinMax();
1337   } else {
1338     target_array->minmax.reset();
1339   }
1340 
1341   if (source_array.quantization_params) {
1342     target_array->GetOrCreateQuantizationParams() =
1343         source_array.GetQuantizationParams();
1344   } else {
1345     target_array->quantization_params.reset();
1346   }
1347 }
1348 }  // namespace
1349 
InsertCopyOperator(Model * model,const string & source_array_name,const string & target_array_name)1350 void InsertCopyOperator(Model* model, const string& source_array_name,
1351                         const string& target_array_name) {
1352   // Reshape to the same size. This should be a no-op.
1353   const Array& source_array = model->GetArray(source_array_name);
1354   std::vector<int> shape = source_array.shape().dims();
1355 
1356   // Drop constant data from the target array as the copy will be done at
1357   // runtime.
1358   Array& target_array = model->GetOrCreateArray(target_array_name);
1359   target_array.buffer.reset();
1360   CopyArrayAttribs(source_array, &target_array);
1361 
1362   // Insert copy operator.
1363   auto* copy_op = new TensorFlowReshapeOperator;
1364   copy_op->inputs = {
1365       source_array_name,
1366       CreateInt32Array(
1367           model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
1368           shape)};
1369   copy_op->outputs = {target_array_name};
1370   if (target_array.has_shape()) {
1371     copy_op->shape = target_array.shape().dims();
1372   }
1373   model->operators.emplace_back(copy_op);
1374 }
1375 
CloneArray(Model * model,const string & source_array_name,const string & target_array_name)1376 void CloneArray(Model* model, const string& source_array_name,
1377                 const string& target_array_name) {
1378   CHECK(!model->HasArray(target_array_name));
1379   const Array& source_array = model->GetArray(source_array_name);
1380   Array& target_array = model->GetOrCreateArray(target_array_name);
1381   CopyArrayAttribs(source_array, &target_array);
1382 
1383   if (source_array.minmax) {
1384     const auto& smm = source_array.GetMinMax();
1385     auto& tmm = target_array.GetOrCreateMinMax();
1386     tmm.min = smm.min;
1387     tmm.max = smm.max;
1388   }
1389 
1390   if (source_array.quantization_params) {
1391     const auto& sqp = source_array.GetQuantizationParams();
1392     auto& tqp = target_array.GetOrCreateQuantizationParams();
1393     tqp.zero_point = sqp.zero_point;
1394     tqp.scale = sqp.scale;
1395   }
1396 
1397   target_array.data_type = source_array.data_type;
1398   target_array.final_data_type = source_array.final_data_type;
1399   target_array.copy_shape(source_array.shape());
1400 
1401   switch (source_array.data_type) {
1402     case ArrayDataType::kBool:
1403       CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
1404       break;
1405     case ArrayDataType::kFloat:
1406       CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
1407       break;
1408     case ArrayDataType::kInt8:
1409       CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
1410       break;
1411     case ArrayDataType::kUint8:
1412       CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
1413       break;
1414     case ArrayDataType::kInt16:
1415       CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
1416       break;
1417     case ArrayDataType::kUint16:
1418       CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
1419       break;
1420     case ArrayDataType::kInt32:
1421       CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
1422       break;
1423     case ArrayDataType::kUint32:
1424       CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
1425       break;
1426     case ArrayDataType::kInt64:
1427       CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
1428       break;
1429     case ArrayDataType::kUint64:
1430       CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
1431       break;
1432     case ArrayDataType::kString:
1433       CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
1434       break;
1435     case ArrayDataType::kComplex64:
1436       CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
1437       break;
1438     default:
1439       LOG(FATAL) << "Unsupported data type: "
1440                  << ArrayDataTypeName(source_array.data_type);
1441       return;
1442   }
1443 }
1444 
MakeArrayDims(int num_dims,int batch,int height,int width,int depth,std::vector<int> * out_dims)1445 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1446                    std::vector<int>* out_dims) {
1447   CHECK(out_dims->empty());
1448   if (num_dims == 0) {
1449     return;
1450   } else if (num_dims == 1) {
1451     CHECK_EQ(batch, 1);
1452     *out_dims = {depth};
1453   } else if (num_dims == 2) {
1454     *out_dims = {batch, depth};
1455   } else if (num_dims == 3) {
1456     CHECK_EQ(batch, 1);
1457     *out_dims = {height, width, depth};
1458   } else if (num_dims == 4) {
1459     *out_dims = {batch, height, width, depth};
1460   } else {
1461     LOG(FATAL) << "Should not get here: " << num_dims;
1462   }
1463 }
1464 
CreateOrCheckRnnStateArray(const string & name,int size,int state_num_dims,Model * model)1465 void CreateOrCheckRnnStateArray(const string& name, int size,
1466                                 int state_num_dims, Model* model) {
1467   int batch = 1;
1468   int num_dims = -1;
1469   if (state_num_dims > 0) {
1470     num_dims = state_num_dims;
1471   } else {
1472     // state_num_dims is not given. We will infer it from an input tensor.
1473     for (const auto& input_array : model->flags.input_arrays()) {
1474       // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
1475       // a better match by name.
1476       if (input_array.name() == name || num_dims == -1) {
1477         num_dims = input_array.shape().dims_size();
1478         if (num_dims > 0) {
1479           batch = input_array.shape().dims(0);
1480         }
1481       }
1482     }
1483   }
1484   Array& array = model->GetOrCreateArray(name);
1485   if (array.has_shape()) {
1486     num_dims = array.shape().dimensions_count();
1487   }
1488   if (!array.has_shape() && num_dims >= 0) {
1489     Shape* shape = array.mutable_shape();
1490     std::vector<int> dims;
1491     MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
1492     *shape->mutable_dims() = dims;
1493   }
1494 }
1495 
ResolveModelFlags(const ModelFlags & model_flags,Model * model)1496 void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1497   // Merge info about input_arrays from model_flags into model->flags
1498   for (const auto& specified_input_array : model_flags.input_arrays()) {
1499     toco::InputArray* dst_input_array = nullptr;
1500     for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1501       toco::InputArray* candidate_dst_input_array =
1502           model->flags.mutable_input_arrays(i);
1503       if (candidate_dst_input_array->name() == specified_input_array.name()) {
1504         // specified_input_array from model_flags maps to dst_input_array
1505         // in model->flags
1506         dst_input_array = candidate_dst_input_array;
1507         break;
1508       }
1509     }
1510     if (!dst_input_array) {
1511       // Specified_input_array from model_flags is not found in model->flags.
1512       // Match a name-less specified input array when there can be no ambiguity
1513       // as there is only 1 input array.
1514       if (model->flags.input_arrays_size() == 1 &&
1515           model_flags.input_arrays_size() == 1 &&
1516           !specified_input_array.has_name()) {
1517         dst_input_array = model->flags.mutable_input_arrays(0);
1518       }
1519     }
1520     if (!dst_input_array) {
1521       // Still no match, so create a new input array to copy
1522       // specified_input_array into.
1523       dst_input_array = model->flags.add_input_arrays();
1524       dst_input_array->set_name(specified_input_array.name());
1525     }
1526 
1527 #define RESOLVE_MODEL_FLAG(field_name)                                       \
1528   if (specified_input_array.has_##field_name()) {                            \
1529     if (dst_input_array->has_##field_name()) {                               \
1530       QCHECK_EQ(dst_input_array->field_name(),                               \
1531                 specified_input_array.field_name())                          \
1532           << "For input array '" << dst_input_array->name() << "', "         \
1533           << "specified " #field_name " flag with value: "                   \
1534           << specified_input_array.field_name()                              \
1535           << " does not agree with already defined " #field_name             \
1536              " of this model, with value: "                                  \
1537           << specified_input_array.field_name();                             \
1538     } else {                                                                 \
1539       dst_input_array->set_##field_name(specified_input_array.field_name()); \
1540     }                                                                        \
1541   }
1542     RESOLVE_MODEL_FLAG(std_value);
1543     RESOLVE_MODEL_FLAG(mean_value);
1544 #undef RESOLVE_MODEL_FLAG
1545 
1546     if (specified_input_array.has_shape()) {
1547       if (dst_input_array->has_shape()) {
1548         QCHECK_EQ(specified_input_array.shape().dims_size(),
1549                   dst_input_array->shape().dims_size())
1550             << "For input array '" << specified_input_array.name() << "', "
1551             << "size of specified input shape flag with size: "
1552             << specified_input_array.shape().dims_size()
1553             << " does not agree with already defined input shape"
1554                " of this model, with size: "
1555             << dst_input_array->shape().dims_size();
1556         // We treat the first dimension as a special case, since it is often
1557         // a batch size and the input_shape flag is effectively overriding
1558         // the model.
1559         for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1560           QCHECK_EQ(specified_input_array.shape().dims(i),
1561                     dst_input_array->shape().dims(i))
1562               << "At dimension number " << i << " of input array "
1563               << specified_input_array.name() << ", the specified shape's "
1564               << "dimension flag with dimension: "
1565               << specified_input_array.shape().dims(i)
1566               << " does not agree with already defined shape"
1567               << " of this model, with dimension: "
1568               << dst_input_array->shape().dims(i);
1569         }
1570       } else {
1571         *dst_input_array->mutable_shape() = specified_input_array.shape();
1572       }
1573     }
1574 
1575     if (specified_input_array.has_data_type()) {
1576       QCHECK(!dst_input_array->has_data_type());
1577       dst_input_array->set_data_type(specified_input_array.data_type());
1578     }
1579   }
1580 
1581   if (model_flags.output_arrays_size() > 0) {
1582     model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1583   }
1584 
1585 #define RESOLVE_MODEL_FLAG(name)                                           \
1586   if (model_flags.has_##name()) {                                          \
1587     if (model->flags.has_##name()) {                                       \
1588       QCHECK_EQ(model_flags.name(), model->flags.name())                   \
1589           << "Specified " #name " flag with value: " << model_flags.name() \
1590           << " does not agree with already defined " #name                 \
1591              " of this model, with value: "                                \
1592           << model->flags.name();                                          \
1593     } else {                                                               \
1594       model->flags.set_##name(model_flags.name());                         \
1595     }                                                                      \
1596   }
1597 
1598   RESOLVE_MODEL_FLAG(variable_batch)
1599 
1600 #undef RESOLVE_MODEL_FLAG
1601 
1602   if (!model_flags.rnn_states().empty()) {
1603     model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1604   }
1605 
1606   if (model->flags.model_checks_size() == 0) {
1607     model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1608   }
1609 
1610   QCHECK_GT(model->flags.output_arrays_size(), 0)
1611       << "This model does not define output arrays, so a "
1612          "--output_arrays flag must be given on the command-line.";
1613 
1614   for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1615     auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1616     if (input_array_proto.has_data_type()) {
1617       const ArrayDataType specified_type =
1618           ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1619       QCHECK(specified_type != ArrayDataType::kNone);
1620       if (input_array.data_type != ArrayDataType::kNone) {
1621         QCHECK(specified_type == input_array.data_type)
1622             << "For input array " << input_array_proto.name()
1623             << " the specified input data type "
1624             << IODataType_Name(input_array_proto.data_type())
1625             << " conflicts with the existing type.";
1626       }
1627       input_array.data_type = specified_type;
1628     }
1629 
1630     if (input_array.data_type == ArrayDataType::kNone) {
1631       // We start out with a float input array;
1632       // that may get replaced by a uint8 array later, by
1633       // MakeInitialDequantizeOp.
1634       input_array.data_type = ArrayDataType::kFloat;
1635     }
1636 
1637     // Compare/merge the model->flags describing the input_shape with
1638     // the actual input array's shape.
1639     if (!input_array.has_shape()) {
1640       if (input_array_proto.has_shape()) {
1641         auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1642         CheckValidShapeDimensions(input_array_proto.shape().dims());
1643         for (const auto& dim : input_array_proto.shape().dims()) {
1644           input_array_dims.push_back(dim);
1645         }
1646       }
1647     } else {
1648       if (input_array_proto.has_shape()) {
1649         // If an input shape was specified on the flags ensure that it matches
1650         // the actual shape in the model.
1651         const auto& input_array_dims =
1652             *input_array.mutable_shape()->mutable_dims();
1653         CHECK_EQ(input_array_dims.size(),
1654                  input_array_proto.shape().dims_size());
1655         for (int i = 0; i < input_array_dims.size(); i++) {
1656           CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1657         }
1658       } else {
1659         for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1660           input_array_proto.mutable_shape()->add_dims(
1661               input_array.shape().dims(i));
1662         }
1663       }
1664     }
1665 
1666     const float mean_value = input_array_proto.mean_value();
1667     const float std_value = input_array_proto.std_value();
1668     MinMax input_minmax;
1669     float qmin = 0, qmax = 255;
1670     if (input_array.data_type == ArrayDataType::kInt16) {
1671       qmin = -32768;
1672       qmax = 32767;
1673     }
1674     input_minmax.min = (qmin - mean_value) / std_value;
1675     input_minmax.max = (qmax - mean_value) / std_value;
1676     if (!input_array.minmax) {
1677       input_array.GetOrCreateMinMax() = input_minmax;
1678     }
1679   }
1680 
1681   // Creation of the RNN state arrays
1682   for (const auto& rnn_state : model->flags.rnn_states()) {
1683     CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1684                                rnn_state.num_dims(), model);
1685   }
1686 
1687   model->flags.set_change_concat_input_ranges(
1688       model_flags.change_concat_input_ranges());
1689   model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1690   model->flags.set_allow_nonexistent_arrays(
1691       model_flags.allow_nonexistent_arrays());
1692 
1693   CHECK(!model->flags.has_arrays_extra_info());
1694   *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1695 }
1696 
CheckIsReadyForQuantization(const Model & model)1697 void CheckIsReadyForQuantization(const Model& model) {
1698   for (const auto& op : model.operators) {
1699     for (const auto& input : op->inputs) {
1700       const auto& input_array = model.GetArray(input);
1701       if (input_array.data_type != ArrayDataType::kFloat) {
1702         // The array is not floats, no quantization needed.
1703         continue;
1704       }
1705       if (input_array.minmax) {
1706         // The array has minmax, we're good.
1707         continue;
1708       }
1709       if (input_array.buffer) {
1710         // The array has a constant buffer, so we can
1711         // fall back to computing the minmax from actual array entries
1712         // (with a WARNING about possible accuracy implications).
1713         continue;
1714       }
1715       LOG(FATAL)
1716           << "Array " << input << ", which is an input to the "
1717           << HelpfulOperatorTypeName(*op) << " operator producing the output "
1718           << "array " << op->outputs[0] << ", is lacking min/max data, "
1719           << "which is necessary for quantization. If accuracy matters, either "
1720           << "target a non-quantized output format, or run quantized training "
1721           << "with your model from a floating point checkpoint to change the "
1722           << "input graph to contain min/max information. If you don't care "
1723           << "about accuracy, you can pass --default_ranges_min= and "
1724           << "--default_ranges_max= for easy experimentation.";
1725     }
1726   }
1727 }
1728 
ElementSize(ArrayDataType data_type)1729 int ElementSize(ArrayDataType data_type) {
1730   switch (data_type) {
1731     case ArrayDataType::kBool:
1732       return sizeof(bool);
1733     case ArrayDataType::kFloat:
1734       return 4;
1735     case ArrayDataType::kInt8:
1736       return 1;
1737     case ArrayDataType::kUint8:
1738       return 1;
1739     case ArrayDataType::kInt16:
1740       return 2;
1741     case ArrayDataType::kUint16:
1742       return 2;
1743     case ArrayDataType::kInt32:
1744       return 4;
1745     case ArrayDataType::kUint32:
1746       return 4;
1747     case ArrayDataType::kInt64:
1748       return 8;
1749     case ArrayDataType::kUint64:
1750       return 8;
1751     case ArrayDataType::kComplex64:
1752       return 8;
1753 
1754     // Usually not critical limitation because strings are only input and/or
1755     // output.
1756     case ArrayDataType::kString:
1757       LOG(FATAL) << "Transient arrays with strings are not supported yet";
1758       return 0;
1759     default:
1760       LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1761       return 0;
1762   }
1763 }
1764 
DropMinMax(Model * model,const string & array_name)1765 void DropMinMax(Model* model, const string& array_name) {
1766   auto& array = model->GetArray(array_name);
1767   if (!!array.minmax) {
1768     LOG(WARNING) << "Dropping MinMax information in array " << array_name
1769                  << ". Expect inaccuracy in quantized inference.";
1770     array.minmax = nullptr;
1771   }
1772 }
1773 
IsAllocatableTransientArray(const Model & model,const string & array_name)1774 bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
1775   // Optional array is not transient
1776   if (model.IsOptionalArray(array_name)) return false;
1777   // The model's input and output arrays are externally allocated.
1778   // They are not transient arrays.
1779   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1780     return false;
1781   }
1782   const auto& array = &model.GetArray(array_name);
1783   // An array with a constant buffer isn't a transient array.
1784   if (!!array->buffer) {
1785     return false;
1786   }
1787   // An array without shape isn't allocatable.
1788   if (!array->has_shape()) {
1789     return false;
1790   }
1791 
1792   // The size of string tensors is rarely known ahead of time, so all transient
1793   // tensors of this type will need to be dynamically allocated.
1794   if (array->final_data_type == ArrayDataType::kString ||
1795       array->data_type == ArrayDataType::kString) {
1796     return false;
1797   }
1798 
1799   return true;
1800 }
1801 
// Returns a TF-node-safe array name derived from `name` that is not yet used
// by `model`, neither as a regular nor as an optional array. Tries the
// sanitized name itself first, then "<sanitized>_0" .. "<sanitized>_999".
// LOG(FATAL)s if every candidate is taken.
string AvailableArrayName(const Model& model, const string& name) {
  const auto is_unused = [&model](const string& candidate) {
    return !model.HasArray(candidate) && !model.IsOptionalArray(candidate);
  };
  string sanitized_name = SanitizeNameForTFNode(name);
  if (is_unused(sanitized_name)) {
    return sanitized_name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int suffix = 0; suffix < kNumSuffixesToTry; ++suffix) {
    string candidate = toco::port::StringF("%s_%d", sanitized_name, suffix);
    if (is_unused(candidate)) {
      return candidate;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with "
             << sanitized_name << ". Tried " << kNumSuffixesToTry
             << " suffixes, all were taken!";
  return "";
}
1822 
ShapeToString(const Shape & shape)1823 string ShapeToString(const Shape& shape) {
1824   if (shape.dimensions_count() == 0) {
1825     return "[]";
1826   }
1827 
1828   return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1829 }
1830 
PrintArrayShape(Model * model,const string & name)1831 void PrintArrayShape(Model* model, const string& name) {
1832   if (!model->GetArray(name).has_shape()) {
1833     LOG(INFO) << name << " has no shape";
1834     return;
1835   }
1836   LOG(INFO) << name
1837             << " has shape: " << ShapeToString(model->GetArray(name).shape());
1838 }
1839 
IsArrayFullyConnectedWeights(const Model & model,const string & name)1840 bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
1841   bool is_fc_weights = false;
1842   bool is_something_else = false;
1843   for (const auto& op : model.operators) {
1844     for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1845       if (op->inputs[input_index] == name) {
1846         if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1847           is_fc_weights = true;
1848         } else {
1849           is_something_else = true;
1850         }
1851       }
1852     }
1853   }
1854   CHECK(!(is_fc_weights && is_something_else));
1855   return is_fc_weights;
1856 }
1857 
CreateInt32Array(Model * model,const string & param_name,const std::vector<int> & value)1858 string CreateInt32Array(Model* model, const string& param_name,
1859                         const std::vector<int>& value) {
1860   auto param_array_name = AvailableArrayName(*model, param_name);
1861   auto& param_array = model->GetOrCreateArray(param_array_name);
1862   param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1863   param_array.data_type = ArrayDataType::kInt32;
1864   auto& param_array_data =
1865       param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1866   param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1867   for (int i = 0; i < value.size(); ++i) {
1868     param_array_data[i] = value[i];
1869   }
1870   return param_array_name;
1871 }
1872 
EstimateArithmeticOpsCount(const Model & model,const Operator & op,int64 * result)1873 bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
1874                                 int64* result) {
1875   switch (op.type) {
1876     case OperatorType::kFullyConnected:
1877     case OperatorType::kConv:
1878     case OperatorType::kDepthwiseConv: {
1879       const auto& output_array = model.GetArray(op.outputs[0]);
1880       const auto& weights_array = model.GetArray(op.inputs[1]);
1881       if (!output_array.has_shape() || !weights_array.has_shape()) {
1882         return false;
1883       }
1884       int64 cols = 1;
1885       for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
1886         cols *= output_array.shape().dims(i);
1887       }
1888       const int64 cost_per_col =
1889           2 * RequiredBufferSizeForShape(weights_array.shape());
1890       *result = cost_per_col * cols;
1891       if (op.inputs.size() > 2) {
1892         // There is a bias vector. One more op per output value.
1893         *result += RequiredBufferSizeForShape(output_array.shape());
1894       }
1895       break;
1896     }
1897     case OperatorType::kAdd:
1898     case OperatorType::kSub:
1899     case OperatorType::kMul: {
1900       const auto& output_array = model.GetArray(op.outputs[0]);
1901       if (!output_array.has_shape()) {
1902         return false;
1903       }
1904       *result = RequiredBufferSizeForShape(output_array.shape());
1905       break;
1906     }
1907     case OperatorType::kAddN: {
1908       const auto& output_array = model.GetArray(op.outputs[0]);
1909       if (!output_array.has_shape()) {
1910         return false;
1911       }
1912       // AddN cost is roughly the same cost as N-1 Adds.
1913       const int64 num_adds = op.inputs.size() - 1;
1914       *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
1915       break;
1916     }
1917     case OperatorType::kLogistic:
1918     case OperatorType::kSoftmax:
1919     case OperatorType::kLogSoftmax:
1920     case OperatorType::kTanh: {
1921       const auto& output_array = model.GetArray(op.outputs[0]);
1922       if (!output_array.has_shape()) {
1923         return false;
1924       }
1925       // As a very rough ballpark, the cost of evaluating a math function
1926       // such as tanh or logistic is about 32 multiplications, and about as
1927       // many additions/subtractions. (Just a power-of-two order-of-magnitude
1928       // from looking at actual implementations that we use in runtime/ code).
1929       *result = 64 * RequiredBufferSizeForShape(output_array.shape());
1930       break;
1931     }
1932     case OperatorType::kMaxPool: {
1933       const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
1934       const auto& output_array = model.GetArray(op.outputs[0]);
1935       if (!output_array.has_shape()) {
1936         return false;
1937       }
1938       *result = RequiredBufferSizeForShape(output_array.shape()) *
1939                 maxpool.kheight * maxpool.kwidth;
1940       break;
1941     }
1942     case OperatorType::kAveragePool: {
1943       const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
1944       const auto& output_array = model.GetArray(op.outputs[0]);
1945       if (!output_array.has_shape()) {
1946         return false;
1947       }
1948       *result = RequiredBufferSizeForShape(output_array.shape()) *
1949                 avgpool.kheight * avgpool.kwidth;
1950       break;
1951     }
1952     case OperatorType::kL2Pool: {
1953       const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
1954       const auto& output_array = model.GetArray(op.outputs[0]);
1955       if (!output_array.has_shape()) {
1956         return false;
1957       }
1958       // The sum of squares requires (kheight*kwidth) multiply-adds,
1959       // and then there is the sqrt which we ballpark at 32 ops.
1960       const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
1961       *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
1962       break;
1963     }
1964     case OperatorType::kL2Normalization: {
1965       const auto& output_array = model.GetArray(op.outputs[0]);
1966       if (!output_array.has_shape()) {
1967         return false;
1968       }
1969       // Computing the squared L2 norm is N multiply-adds so 2N ops,
1970       // then the single inverse-sqrt is negligible, then we multiply each
1971       // value by the resulting multiplier, so an extra N ops. count 3N ops.
1972       *result = 3 * RequiredBufferSizeForShape(output_array.shape());
1973       break;
1974     }
1975     default:
1976       *result = 0;
1977       break;
1978   }
1979   return true;
1980 }
1981 
EstimateArithmeticOpsCount(const Model & model,int64 * result)1982 bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
1983   int64 total = 0;
1984   for (const auto& op : model.operators) {
1985     int64 num_ops;
1986     if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
1987       return false;
1988     }
1989     total += num_ops;
1990   }
1991   *result = total;
1992   return true;
1993 }
1994 
FormattedNumber(int64 x)1995 string FormattedNumber(int64 x) {
1996   const int64 million = 1000000;
1997   const int64 billion = 1000000000;
1998   if (x < 10000) {
1999     return toco::port::StringF("%d ", x);
2000   } else if (x < billion) {
2001     return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2002   } else {
2003     return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2004   }
2005 }
2006 
GetShuffleShape(AxesOrder input_axes_order,AxesOrder output_axes_order,std::vector<int> * shuffle)2007 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
2008                      std::vector<int>* shuffle) {
2009   CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
2010   shuffle->resize(4);
2011   for (int i = 0; i < 4; i++) {
2012     (*shuffle)[i] = i;
2013   }
2014   if (input_axes_order == output_axes_order) {
2015     // nothing to do
2016   } else if (AxesCount(input_axes_order) == 2) {
2017     shuffle->resize(2);
2018     (*shuffle)[0] = 1;
2019     (*shuffle)[1] = 0;
2020   } else if (input_axes_order == AxesOrder::kOHWI &&
2021              output_axes_order == AxesOrder::kHWIO) {
2022     // 3210 <- 3210
2023     // HWIO <- OHWI
2024     *shuffle = {1, 2, 3, 0};
2025   } else if (input_axes_order == AxesOrder::kHWIO &&
2026              output_axes_order == AxesOrder::kOHWI) {
2027     // 3210 <- 3210
2028     // OHWI <- HWIO
2029     *shuffle = {3, 0, 1, 2};
2030   } else if (input_axes_order == AxesOrder::kOHWI &&
2031              output_axes_order == AxesOrder::kHWOI) {
2032     *shuffle = {1, 2, 0, 3};
2033   } else {
2034     LOG(FATAL) << "Bad shuffle";
2035   }
2036 }
2037 
ExtendShuffle(const std::vector<int> & input_shuffle,int newdim,std::vector<int> * extended_shuffle)2038 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2039                    std::vector<int>* extended_shuffle) {
2040   *extended_shuffle = input_shuffle;
2041   CHECK(newdim >= input_shuffle.size());
2042   const int pad_size = newdim - input_shuffle.size();
2043   extended_shuffle->resize(newdim);
2044   for (int i = 0; i < pad_size; i++) {
2045     (*extended_shuffle)[i] = i;
2046   }
2047   for (int i = pad_size; i < newdim; i++) {
2048     (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2049   }
2050 }
2051 
ShuffleDims(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,Shape * output_shape)2052 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2053                  AxesOrder output_axes_order, Shape* output_shape) {
2054   if (input_axes_order == AxesOrder::kHWIM &&
2055       output_axes_order == AxesOrder::k1HWO) {
2056     // This special case isn't just a permutation, the IM pair of dims get
2057     // merged into the 3 dim, so we have to special-case it.
2058     *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2059                            input_shape.dims(3) * input_shape.dims(2)});
2060   } else {
2061     std::vector<int> shuffle;
2062     GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2063     std::vector<int>* output_dims = output_shape->mutable_dims();
2064     output_dims->resize(input_shape.dimensions_count());
2065     for (int i = 0; i < input_shape.dimensions_count(); i++) {
2066       (*output_dims)[i] = input_shape.dims(shuffle[i]);
2067     }
2068   }
2069 }
2070 
// Copies `input_data`, laid out with `input_shape` in `input_axes_order`,
// into `output_data`, laid out with `output_shape` in `output_axes_order`,
// permuting the axes as given by GetShuffleShape (output axis i comes from
// input axis shuffle[i]). Only ranks <= 4 are supported. Lower-rank shapes
// are handled by padding both shapes and the permutation to 4 axes.
// `output_data` must be able to hold RequiredBufferSizeForShape(output_shape)
// elements. The HWIM -> 1HWO case is an element-wise identity and is done
// with memcpy.
template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  // Validate that output_shape is exactly input_shape permuted by shuffle.
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Pad shapes and permutation to 4 axes so a single 4-deep loop nest
  // handles every supported rank. The permutation gets identity leading
  // entries via ExtendShuffle. ExtendShape presumably adds matching leading
  // size-1 dims (TODO confirm).
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major element strides of the (extended) input, indexed by input axis.
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  // Suffix _k below counts output axes from the innermost (_0 == axis 3).
  // Stepping output axis a moves through the input along axis
  // extended_shuffle[a], hence the stride lookup through the permutation.
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // The output is written densely in row-major order.
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  // Walk the output contiguously, gathering each element from its permuted
  // input position.
  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}
2145 
ShuffleArray(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,const Shape & output_shape,const uint8 * input_data,uint8 * output_data)2146 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2147                   AxesOrder output_axes_order, const Shape& output_shape,
2148                   const uint8* input_data, uint8* output_data) {
2149   ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
2150                               output_shape, input_data, output_data);
2151 }
2152 
ShuffleArray(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,const Shape & output_shape,const float * input_data,float * output_data)2153 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2154                   AxesOrder output_axes_order, const Shape& output_shape,
2155                   const float* input_data, float* output_data) {
2156   ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
2157                               output_shape, input_data, output_data);
2158 }
2159 
AxesCount(AxesOrder axes_order)2160 int AxesCount(AxesOrder axes_order) {
2161   switch (axes_order) {
2162     case AxesOrder::kOneAxis:
2163       return 1;
2164     case AxesOrder::kRC:
2165       return 2;
2166     case AxesOrder::kCR:
2167       return 2;
2168     case AxesOrder::kHWIO:
2169       return 4;
2170     case AxesOrder::kOHWI:
2171       return 4;
2172     case AxesOrder::kHWIM:
2173       return 4;
2174     case AxesOrder::k1HWO:
2175       return 4;
2176     case AxesOrder::kNHWC:
2177       return 4;
2178     case AxesOrder::kHWOI:
2179       return 4;
2180     default:
2181       LOG(FATAL) << "Bad AxesOrder";
2182       return 0;
2183   }
2184 }
2185 
IsDiscardableArray(const Model & model,const string & array_name)2186 bool IsDiscardableArray(const Model& model, const string& array_name) {
2187   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2188     return false;
2189   }
2190   for (const auto& rnn_state : model.flags.rnn_states()) {
2191     if (!rnn_state.discardable()) {
2192       if (array_name == rnn_state.state_array()) {
2193         return false;
2194       }
2195       if (array_name == rnn_state.back_edge_source_array()) {
2196         return false;
2197       }
2198     }
2199   }
2200   return true;
2201 }
2202 
ReshapeIsEquivalentToTranspose(const Model & model,const TensorFlowReshapeOperator * op,bool allow_extra_unary_dims)2203 bool ReshapeIsEquivalentToTranspose(const Model& model,
2204                                     const TensorFlowReshapeOperator* op,
2205                                     bool allow_extra_unary_dims) {
2206   CHECK(!op->shape.empty());
2207   CHECK(model.HasArray(op->inputs[0]));
2208   CHECK(model.HasArray(op->outputs[0]));
2209 
2210   const auto& input_array = model.GetArray(op->inputs[0]);
2211   const auto& output_array = model.GetArray(op->outputs[0]);
2212 
2213   CHECK(input_array.has_shape());
2214   CHECK(output_array.has_shape());
2215 
2216   std::vector<int> in_shape = input_array.shape().dims();
2217   std::vector<int> out_shape = output_array.shape().dims();
2218 
2219   // If the reshape changes the number of dimensions so it cannot be interpreted
2220   // as a transpose.
2221   if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2222     return false;
2223   }
2224 
2225   in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2226                  in_shape.end());
2227   out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2228                   out_shape.end());
2229   return in_shape == out_shape;
2230 }
2231 
CheckFinalDataTypesSatisfied(const Model & model)2232 void CheckFinalDataTypesSatisfied(const Model& model) {
2233   for (const auto& array_entry : model.GetArrayMap()) {
2234     const auto& array = *array_entry.second;
2235     if (array.data_type == ArrayDataType::kBool) {
2236       // Boolean values are never quantized.
2237       continue;
2238     }
2239 
2240     // If the final data type is int16, the data type may be float, for example
2241     // after dequantization.
2242     if (array.final_data_type != ArrayDataType::kNone &&
2243         array.final_data_type != ArrayDataType::kInt16) {
2244       CHECK(array.data_type == array.final_data_type)
2245           << "Array \"" << array_entry.first
2246           << "\" has mis-matching actual and final data types (data_type="
2247           << ArrayDataTypeName(array.data_type)
2248           << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2249           << ").";
2250     }
2251   }
2252 }
2253 
ConvertIODataTypeToArrayDataType(IODataType type)2254 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
2255   switch (type) {
2256     case FLOAT:
2257       return ArrayDataType::kFloat;
2258     case QUANTIZED_UINT8:
2259       return ArrayDataType::kUint8;
2260     case INT8:
2261       return ArrayDataType::kInt8;
2262     case QUANTIZED_INT16:
2263       return ArrayDataType::kInt16;
2264     case INT32:
2265       return ArrayDataType::kInt32;
2266     case INT64:
2267       return ArrayDataType::kInt64;
2268     case BOOL:
2269       return ArrayDataType::kBool;
2270     case STRING:
2271       return ArrayDataType::kString;
2272     case COMPLEX64:
2273       return ArrayDataType::kComplex64;
2274     default:
2275       return ArrayDataType::kNone;
2276   }
2277 }
2278 
FinishBuildingRNNStates(Model * model)2279 void FinishBuildingRNNStates(Model* model) {
2280   for (const auto& rnn_state : model->flags.rnn_states()) {
2281     if (!model->HasArray(rnn_state.back_edge_source_array()) ||
2282         !model->HasArray(rnn_state.state_array())) {
2283       CHECK(model->HasArray(rnn_state.back_edge_source_array()));
2284       CHECK(model->HasArray(rnn_state.state_array()));
2285       continue;
2286     }
2287     const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
2288     auto& dst_array = model->GetArray(rnn_state.state_array());
2289     if (src_array.data_type == ArrayDataType::kNone &&
2290         dst_array.data_type == ArrayDataType::kNone) {
2291       dst_array.data_type = ArrayDataType::kFloat;
2292     }
2293   }
2294 }
2295 
2296 // Returns the array names that match the ArraysExtraInfo's name and
2297 // name_regexp. The regexp match is for a full match.
ScanArrayNames(const Model & model,const toco::ArraysExtraInfo_Entry & entry)2298 std::unordered_set<string> ScanArrayNames(
2299     const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2300   std::unordered_set<string> matches;
2301   if (model.HasArray(entry.name())) {
2302     matches.insert(entry.name());
2303   }
2304   if (!entry.name_regexp().empty()) {
2305     const auto& arrays = model.GetArrayMap();
2306     const RE2 name_regexp = {entry.name_regexp()};
2307     for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2308       if (RE2::FullMatch(it->first, name_regexp)) {
2309         matches.insert(it->first);
2310       }
2311     }
2312   }
2313   return matches;
2314 }
2315 
UseArraysExtraInfo(Model * model,bool quantize_output)2316 void UseArraysExtraInfo(Model* model, bool quantize_output) {
2317   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
2318     const auto matches = ScanArrayNames(*model, entry);
2319     for (const auto& matched_name : matches) {
2320       auto& array = model->GetArray(matched_name);
2321       if (entry.has_min() || entry.has_max()) {
2322         CHECK_EQ(entry.has_min(), entry.has_max());
2323         auto& minmax = array.GetOrCreateMinMax();
2324         minmax.min = entry.min();
2325         minmax.max = entry.max();
2326       }
2327       if (entry.has_data_type() && quantize_output) {
2328         array.final_data_type =
2329             ConvertIODataTypeToArrayDataType(entry.data_type());
2330       }
2331       if (entry.has_shape()) {
2332         array.clear_shape();
2333         // Make sure to create the shape even if there are no dims, to
2334         // correctly record 0-D shapes.
2335         array.mutable_shape();
2336         for (const auto& dim : entry.shape().dims()) {
2337           array.mutable_shape()->mutable_dims()->push_back(dim);
2338         }
2339       }
2340       if (entry.has_constant_float_value()) {
2341         CHECK(array.has_shape());
2342         if (array.data_type == ArrayDataType::kFloat) {
2343           auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
2344           data.resize(RequiredBufferSizeForShape(array.shape()));
2345           for (float& f : data) {
2346             f = entry.constant_float_value();
2347           }
2348         }
2349       }
2350     }
2351   }
2352 }
2353 
UndoWeightsShuffling(Model * model)2354 void UndoWeightsShuffling(Model* model) {
2355   for (const auto& op : model->operators) {
2356     if (op->type != toco::OperatorType::kFullyConnected) {
2357       continue;
2358     }
2359     const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
2360     if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
2361       continue;
2362     }
2363     const string& weights_name = fc_op.inputs[1];
2364     QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
2365     auto& weights_array = model->GetArray(weights_name);
2366     QCHECK(weights_array.data_type == ArrayDataType::kUint8);
2367     auto& weights_data =
2368         weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
2369     const auto& weights_shape = weights_array.shape();
2370     QCHECK_EQ(weights_shape.dimensions_count(), 2);
2371     const int rows = weights_shape.dims(0);
2372     const int cols = weights_shape.dims(1);
2373     QCHECK_EQ(rows % 4, 0);
2374     QCHECK_EQ(cols % 16, 0);
2375     CHECK_EQ(rows * cols, weights_data.size());
2376     // Compute the de-shuffled weights
2377     std::vector<uint8> deshuffled_data(weights_data.size());
2378     uint8* shuffled_data_ptr = weights_data.data();
2379     for (int r = 0; r < rows; r += 4) {
2380       for (int c = 0; c < cols; c += 16) {
2381         for (int i = 0; i < 4; i++) {
2382           uint8* deshuffled_data_ptr =
2383               deshuffled_data.data() + (r + i) * cols + c;
2384           for (int j = 0; j < 16; j++) {
2385             uint8 shuffled_val = *shuffled_data_ptr++;
2386             // Deshuffling isn't only about deshuffling the storage layout,
2387             // it's also about undoing the flipping of the sign bit, which is
2388             // performed on the shuffled weights.
2389             uint8 deshuffled_val = shuffled_val ^ 0x80;
2390             *deshuffled_data_ptr++ = deshuffled_val;
2391           }
2392         }
2393       }
2394     }
2395     CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
2396     // Switch this FC op to using the deshuffled weights.
2397     weights_data = std::move(deshuffled_data);
2398   }
2399 }
2400 
CopyMinMaxAndQuantizationRelatedFields(const Array & src,Array * dst)2401 void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
2402   if (src.minmax) {
2403     dst->GetOrCreateMinMax() = src.GetMinMax();
2404   }
2405   if (src.quantization_params) {
2406     dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
2407   }
2408   dst->narrow_range = src.narrow_range;
2409 }
2410 
2411 }  // namespace toco
2412