1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/versioning/op_version.h"
16 
17 #include <algorithm>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/schema/schema_utils.h"
26 
27 namespace tflite {
28 namespace {
29 
30 static const auto kTensorTypeNone = static_cast<::tflite::TensorType>(-1);
31 
32 // Get the number of dimensions of a tensor with idx of an operator op.
GetNumDims(const SubGraph * subgraph,const Operator * op,int idx)33 inline int GetNumDims(const SubGraph* subgraph, const Operator* op, int idx) {
34   return subgraph->tensors()->Get(op->inputs()->Get(idx))->shape()->size();
35 }
36 
37 // Compare shape of two tensors with idx1 and idx2 of an operator op, return
38 // true if they have the same shape.
HaveSameShapes(const SubGraph * subgraph,const Operator * op,int idx1,int idx2)39 inline bool HaveSameShapes(const SubGraph* subgraph, const Operator* op,
40                            int idx1, int idx2) {
41   const flatbuffers::Vector<int32_t>* shape1 =
42       subgraph->tensors()->Get(op->inputs()->Get(idx1))->shape();
43   const flatbuffers::Vector<int32_t>* shape2 =
44       subgraph->tensors()->Get(op->inputs()->Get(idx2))->shape();
45   if (shape1->size() != shape2->size()) {
46     return false;
47   }
48   return std::equal(shape1->begin(), shape1->end(), shape2->begin());
49 }
50 }  // namespace
51 
GetBuiltinOperatorVersion(const OpSignature & op_sig)52 int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
53   switch (op_sig.op) {
54     case BuiltinOperator_CONV_2D:
55       // If the op has signed int16 op_sig.inputs and op_sig.outputs, its
56       // version 4.
57       if (op_sig.input_types.at(0) == TensorType_INT16 &&
58           op_sig.input_types.at(1) == TensorType_INT16 &&
59           op_sig.output_types.at(1) == TensorType_INT16) {
60         return 4;
61       }
62 
63       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
64       // version 3.
65       if (op_sig.input_types.at(0) == TensorType_INT8 &&
66           op_sig.input_types.at(1) == TensorType_INT8 &&
67           op_sig.output_types.at(0) == TensorType_INT8) {
68         return 3;
69       }
70       // If the op is a signed int8 hybrid operation, we need to return
71       // version 2 or 5 if per channel.
72       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
73           op_sig.input_types.at(1) == TensorType_INT8 &&
74           op_sig.output_types.at(0) == TensorType_FLOAT32) {
75         if (op_sig.options.conv_2d.is_per_channel_quantized) {
76           return 5;
77         }
78         return 2;
79       }
80       return 1;
81 
82     case BuiltinOperator_DEPTHWISE_CONV_2D:
83       // If the op accepts int16, we return version 5.
84       if (op_sig.input_types.at(0) == TensorType_INT16 &&
85           op_sig.input_types.at(1) == TensorType_INT16 &&
86           op_sig.output_types.at(1) == TensorType_INT16) {
87         return 5;
88       }
89 
90       // If the op is a signed int8 hybrid operation, we need to return
91       // version 4 or 6 if per-channel.
92       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
93           op_sig.input_types.at(1) == TensorType_INT8 &&
94           op_sig.output_types.at(0) == TensorType_FLOAT32) {
95         if (op_sig.options.depthwise_conv_2d.is_per_channel_quantized) {
96           return 6;
97         }
98         return 4;
99       }
100       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
101       // version 3.
102       if (op_sig.input_types.at(0) == TensorType_INT8 &&
103           op_sig.input_types.at(1) == TensorType_INT8 &&
104           op_sig.output_types.at(0) == TensorType_INT8) {
105         return 3;
106       }
107       if (op_sig.options.depthwise_conv_2d.dilation_w_factor != 1 ||
108           op_sig.options.depthwise_conv_2d.dilation_h_factor != 1) {
109         return 2;
110       }
111       return 1;
112 
113     case BuiltinOperator_FAKE_QUANT:
114       if (op_sig.options.fakequant.narrow_range) {
115         return 2;
116       }
117       return 1;
118 
119     case BuiltinOperator_FULLY_CONNECTED:
120       // +-----------------+--------------------+--------------------------+
121       // |                 |    Weight::Default | Weight::Shuffled4x16Int8 |
122       // +-----------------+--------------------+--------------------------+
123       // | Float           |                  1 |                        2 |
124       // | Quantized Uint8 |                  1 |                        2 |
125       // | Hybrid          |                  3 |                        3 |
126       // | Quantized Int8  |                  4 |                        4 |
127       // +-----------------+--------------------+--------------------------+
128 
129       // FullyConnected with sparse weight is supported at version 8.
130       if (op_sig.options.fully_connected.sparse_weight) {
131         return 8;
132       }
133 
134       // Int16 fully fixed point kernel is at version 7.
135       if (op_sig.input_types.at(0) == TensorType_INT16 &&
136           op_sig.input_types.at(1) == TensorType_INT16 &&
137           op_sig.output_types.at(0) == TensorType_INT16) {
138         return 7;
139       }
140 
141       // 2 op_sig.inputs (no bias) use case is supported starting from
142       // version 6.
143       if (op_sig.input_types.size() == 2) {
144         return 6;
145       }
146       // `keep_num_dims` is supported at version 5.
147       if (op_sig.options.fully_connected.keep_num_dims) {
148         return 5;
149       }
150       // Int8 fully fixed point kernel is at version 4.
151       if (op_sig.input_types.at(0) == TensorType_INT8 &&
152           op_sig.input_types.at(1) == TensorType_INT8 &&
153           op_sig.output_types.at(0) == TensorType_INT8) {
154         return 4;
155       }
156       // If the op is a signed int8 hybrid operation, we need to return
157       // version 3.
158       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
159           op_sig.input_types.at(1) == TensorType_INT8 &&
160           op_sig.output_types.at(0) == TensorType_FLOAT32) {
161         if (op_sig.options.fully_connected.asymmetric_quantize_inputs) {
162           // This is to use the updated quantization scheme.
163           return 9;
164         }
165         return 3;
166       }
167       // For float and uint8 fixed point kernels, if the weight is
168       // Shuffled4x16Int8, it is version 2.
169       if (op_sig.options.fully_connected.weights_format ==
170           FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
171         return 2;
172       }
173       // Otherwise (weight is default), the version is 1.
174       return 1;
175 
176     case BuiltinOperator_GATHER:
177       if (op_sig.input_types.at(0) == TensorType_INT16) {
178         return 4;
179       }
180       // If the op takes bool input, it is version 3.
181       if (op_sig.input_types.at(0) == TensorType_BOOL) {
182         return 3;
183       }
184       if (op_sig.input_types.at(0) == TensorType_INT8) {
185         return 2;
186       }
187       return 1;
188 
189     case BuiltinOperator_SVDF:
190       // Fully integer SVDF has int8 as input and is of version 3.
191       if (op_sig.input_types.at(0) == TensorType_INT8) {
192         return 3;
193       }
194       // If the op is a signed int8 hybrid operation, we need to return
195       // version 2.
196       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
197           op_sig.input_types.at(1) == TensorType_INT8 &&
198           op_sig.output_types.at(0) == TensorType_FLOAT32) {
199         // This is to use the updated quantization scheme
200         if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
201           return 4;
202         }
203         return 2;
204       }
205       return 1;
206 
207     case BuiltinOperator_MUL:
208       // Version 4 supports int16 inputs
209       if (op_sig.input_types.at(0) == TensorType_INT16) {
210         return 4;
211       }
212       // Version 3 supports have a rescale value greater than or equal to 1.
213       if (op_sig.options.mul.input1_scale != 0 &&
214           op_sig.options.mul.input2_scale != 0 &&
215           op_sig.options.mul.output_scale != 0 &&
216           (op_sig.options.mul.input1_scale * op_sig.options.mul.input2_scale /
217            op_sig.options.mul.output_scale) >= 1.0) {
218         return 3;
219       }
220       if (op_sig.input_types.at(0) == TensorType_INT8) {
221         return 2;
222       }
223       return 1;
224 
225     case BuiltinOperator_MAX_POOL_2D:
226     case BuiltinOperator_AVERAGE_POOL_2D:
227       if (op_sig.input_types.at(0) == TensorType_INT16 &&
228           op_sig.output_types.at(0) == TensorType_INT16) {
229         return 3;
230       }
231 
232       if (op_sig.input_types.at(0) == TensorType_INT8) {
233         return 2;
234       }
235       return 1;
236 
237     case BuiltinOperator_TRANSPOSE:
238       if (op_sig.input_types.at(0) == TensorType_INT16) {
239         return 5;
240       }
241       if (op_sig.options.single_input_op.num_dims > 4) {
242         return 4;
243       }
244       // If the op takes bool input, it is version 3.
245       if (op_sig.input_types.at(0) == TensorType_BOOL) {
246         return 3;
247       }
248       if (op_sig.input_types.at(0) == TensorType_INT8) {
249         return 2;
250       }
251       return 1;
252 
253     case BuiltinOperator_TRANSPOSE_CONV: {
254       if (op_sig.input_types.size() == 4 &&
255           op_sig.input_types.at(3) != kTensorTypeNone) {
256         return 3;
257       }
258       // If the op takes int8 input, it is version 2.
259       if (op_sig.input_types.at(1) == TensorType_INT8) {
260         return 2;
261       }
262       return 1;
263     }
264 
265     case BuiltinOperator_LSTM:
266       // If the input tensor is float and a weight is int8, this is a version
267       // 3 hybrid operation.
268       if (op_sig.options.lstm.kernel_type == LSTMKernelType_FULL &&
269           op_sig.input_types.at(0) == TensorType_FLOAT32 &&
270           op_sig.input_types.at(2) == TensorType_INT8 &&
271           op_sig.output_types.at(0) == TensorType_FLOAT32) {
272         if (op_sig.options.lstm.asymmetric_quantize_inputs) {
273           return 4;
274         }
275         return 3;
276       }
277       // KERNEL_BASIC was added in version 2.
278       if (op_sig.options.lstm.kernel_type == LSTMKernelType_BASIC) {
279         return 2;
280       }
281       return 1;
282 
283     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
284       // If the input tensor is float and a weight is int8, this is a version
285       // 2 hybrid operation.
286       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
287           op_sig.input_types.at(2) == TensorType_INT8 &&
288           op_sig.output_types.at(0) == TensorType_FLOAT32) {
289         if (op_sig.options.lstm.asymmetric_quantize_inputs) {
290           return 3;
291         }
292         return 2;
293       }
294       return 1;
295 
296     case BuiltinOperator_SPLIT:
297       // If the op take in16 input, it is version 4.
298       if (op_sig.input_types.at(1) == TensorType_INT16) {
299         return 4;
300       }
301       // If the op take int8 input, it is version 2, for int32 it's version 3.
302       // The input tensor is at index 1 not 0, 0 is the axis.
303       if (op_sig.input_types.at(1) == TensorType_INT32) {
304         return 3;
305       }
306       if (op_sig.input_types.at(1) == TensorType_INT8) {
307         return 2;
308       }
309       return 1;
310 
311     case BuiltinOperator_SPARSE_TO_DENSE:
312       // Version 3 supports Int8 and Uint8 type.
313       if (op_sig.input_types.at(2) == TensorType_INT8 ||
314           op_sig.input_types.at(2) == TensorType_UINT8) {
315         return 3;
316       }
317       // Version 2 supports Int64 value type.
318       if (op_sig.input_types.at(2) == TensorType_INT64) {
319         return 2;
320       }
321       return 1;
322 
323     case BuiltinOperator_SLICE:
324       if (op_sig.options.single_input_op.num_dims > 4) {
325         return 5;
326       }
327       if (op_sig.input_types.at(0) == TensorType_INT16) {
328         return 4;
329       }
330       // Version 3 supports string input types.
331       if (op_sig.input_types.at(0) == TensorType_STRING) {
332         return 3;
333       }
334       if (op_sig.input_types.at(0) == TensorType_INT8) {
335         return 2;
336       }
337       return 1;
338 
339     case BuiltinOperator_UNPACK:
340       // If the op take int8/uint8 input, it is version 2.
341       if (op_sig.input_types.at(0) == TensorType_INT8 ||
342           op_sig.input_types.at(0) == TensorType_UINT8) {
343         return 2;
344       }
345       // If the op take bool input, it is version 3.
346       if (op_sig.input_types.at(0) == TensorType_BOOL) {
347         return 3;
348       }
349       if (op_sig.input_types.at(0) == TensorType_INT16 &&
350           op_sig.output_types.at(0) == TensorType_INT16) {
351         return 4;
352       }
353       return 1;
354 
355     case BuiltinOperator_DEQUANTIZE:
356       // Version 3 supports signed int16 input types.
357       if (op_sig.input_types.at(0) == TensorType_INT16 ||
358           op_sig.input_types.at(0) == TensorType_FLOAT16) {
359         return 3;
360       }
361       if (op_sig.input_types.at(0) == TensorType_INT8) {
362         return 2;
363       }
364       return 1;
365 
366     case BuiltinOperator_FLOOR_DIV:
367       if (op_sig.input_types.at(0) == TensorType_FLOAT32) {
368         return 2;
369       }
370       return 1;
371 
372     case BuiltinOperator_L2_NORMALIZATION:
373       if (op_sig.output_types.at(0) == TensorType_INT8) {
374         return 2;
375       }
376       return 1;
377 
378     case BuiltinOperator_ABS:
379     case BuiltinOperator_RELU:
380       if (op_sig.input_types.at(0) == TensorType_INT16) {
381         return 3;
382       }
383       if (op_sig.input_types.at(0) == TensorType_INT8 ||
384           op_sig.input_types.at(0) == TensorType_UINT8) {
385         return 2;
386       }
387       return 1;
388 
389     case BuiltinOperator_STRIDED_SLICE:
390       if (op_sig.input_types.at(0) == TensorType_STRING) {
391         return 5;
392       }
393       if (op_sig.options.single_input_op.num_dims > 4) {
394         return 4;
395       }
396       // If the op takes bool input, it is version 3.
397       if (op_sig.input_types.at(0) == TensorType_BOOL) {
398         return 3;
399       }
400       if (op_sig.input_types.at(0) == TensorType_INT8) {
401         return 2;
402       }
403       return 1;
404     case BuiltinOperator_REVERSE_V2:
405       if (op_sig.input_types.at(0) == TensorType_INT8) {
406         return 3;
407       }
408       if (op_sig.input_types.at(0) == TensorType_BOOL) {
409         return 2;
410       }
411       return 1;
412     case BuiltinOperator_RESIZE_BILINEAR:
413       if (op_sig.input_types.at(0) == TensorType_INT16) {
414         return 4;
415       } else if (op_sig.options.resize.half_pixel_centers) {
416         return 3;
417       } else if (op_sig.input_types.at(0) == TensorType_INT8) {
418         return 2;
419       }
420       return 1;
421     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
422       if (op_sig.input_types.at(0) == TensorType_INT16) {
423         return 4;
424       } else if (op_sig.options.resize.half_pixel_centers ||
425                  op_sig.options.resize.align_corners) {
426         return 3;
427       } else if (op_sig.input_types.at(0) == TensorType_INT8) {
428         return 2;
429       }
430       return 1;
431 
432     case BuiltinOperator_MAXIMUM:
433     case BuiltinOperator_MINIMUM:
434       if (op_sig.input_types.at(0) == TensorType_INT16 &&
435           op_sig.output_types.at(0) == TensorType_INT16) {
436         return 4;
437       }
438       if (op_sig.options.broadcast.need_broadcast &&
439           op_sig.options.broadcast.num_dims > 4) {
440         return 3;
441       }
442       if (op_sig.input_types.at(0) == TensorType_INT8) {
443         return 2;
444       }
445       return 1;
446 
447     case BuiltinOperator_PACK:
448       if (op_sig.input_types.at(0) == TensorType_INT8) {
449         return 2;
450       }
451 
452       if (op_sig.input_types.at(0) == TensorType_INT16 &&
453           op_sig.output_types.at(0) == TensorType_INT16) {
454         return 3;
455       }
456       return 1;
457 
458     case BuiltinOperator_TILE:
459       if (op_sig.input_types.at(0) == TensorType_STRING) {
460         return 2;
461       }
462       return 1;
463 
464     case BuiltinOperator_SQUEEZE:
465       if (op_sig.input_types.at(0) == TensorType_STRING) {
466         return 2;
467       }
468       return 1;
469 
470     case BuiltinOperator_SPACE_TO_BATCH_ND:
471     case BuiltinOperator_BATCH_TO_SPACE_ND:
472       if (op_sig.options.single_input_op.num_dims != 4) {
473         return 3;
474       }
475       if (op_sig.input_types.at(0) == TensorType_INT8) {
476         return 2;
477       }
478       return 1;
479 
480     case BuiltinOperator_ADD:
481       if (op_sig.input_types.at(0) == TensorType_INT16 &&
482           op_sig.output_types.at(0) == TensorType_INT16) {
483         if (!op_sig.options.addsub.pot_scale_int16) {
484           return 3;
485         }
486       }
487       if (op_sig.input_types.at(0) == TensorType_INT8) {
488         return 2;
489       }
490       return 1;
491 
492     case BuiltinOperator_SUB:
493       if (op_sig.input_types.at(0) == TensorType_INT16 &&
494           op_sig.output_types.at(0) == TensorType_INT16) {
495         if (!op_sig.options.addsub.pot_scale_int16) {
496           return 5;
497         }
498       }
499       if (!op_sig.input_types.empty() &&
500           op_sig.input_types.at(0) == TensorType_INT64) {
501         return 4;
502       }
503       if (op_sig.options.addsub.need_broadcast &&
504           op_sig.options.addsub.num_dims > 4) {
505         return 3;
506       }
507       if (op_sig.input_types.at(0) == TensorType_INT8) {
508         return 2;
509       }
510       return 1;
511 
512     case BuiltinOperator_GATHER_ND:
513       if (!op_sig.input_types.empty() &&
514           (op_sig.input_types.at(0) == TensorType_INT16)) {
515         return 3;
516       }
517       if (!op_sig.input_types.empty() &&
518           op_sig.input_types.at(0) == TensorType_STRING) {
519         return 2;
520       }
521       return 1;
522 
523     case BuiltinOperator_DIV:
524       if (op_sig.options.broadcast.need_broadcast &&
525           op_sig.options.broadcast.num_dims > 4) {
526         return 2;
527       }
528       return 1;
529     case BuiltinOperator_TANH:
530     case BuiltinOperator_LOGISTIC:
531       if (op_sig.input_types.at(0) == TensorType_INT16 &&
532           op_sig.output_types.at(0) == TensorType_INT16) {
533         return 3;
534       }
535 
536       if (op_sig.input_types.at(0) == TensorType_INT8) {
537         return 2;
538       }
539       return 1;
540 
541     case BuiltinOperator_FILL:
542       if (op_sig.input_types.size() >= 2) {
543         if (op_sig.input_types.at(1) == TensorType_INT8 ||
544             op_sig.input_types.at(1) == TensorType_INT16) {
545           return 3;
546         } else if ((op_sig.input_types.at(1) == TensorType_BOOL ||
547                     op_sig.input_types.at(1) == TensorType_STRING)) {
548           return 2;
549         }
550       }
551       return 1;
552 
553     case BuiltinOperator_EQUAL:
554     case BuiltinOperator_NOT_EQUAL:
555       if (!op_sig.input_types.empty()) {
556         if (op_sig.input_types.at(0) == TensorType_STRING) {
557           return 3;
558         }
559         if (op_sig.input_types.at(0) == TensorType_INT8) {
560           return 2;
561         }
562       }
563       return 1;
564 
565     case BuiltinOperator_LEAKY_RELU:
566       if (op_sig.input_types.at(0) == TensorType_INT16) {
567         return 2;
568       }
569       return 1;
570 
571     case BuiltinOperator_BATCH_MATMUL:
572       // In case of int16 inputs, the version is 3.
573       if (op_sig.input_types.at(0) == TensorType_INT16) {
574         return 3;
575       }
576       if (op_sig.input_types.at(0) == TensorType_INT8) {
577         return 2;
578       }
579       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
580           op_sig.input_types.at(1) == TensorType_INT8 &&
581           op_sig.output_types.at(0) == TensorType_FLOAT32) {
582         if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
583           // This is to use the updated quantization scheme.
584           return 4;
585         }
586       }
587       return 1;
588 
589     case BuiltinOperator_CONCATENATION:
590     case BuiltinOperator_SOFTMAX:
591     case BuiltinOperator_MEAN:
592     case BuiltinOperator_PAD:
593     case BuiltinOperator_PADV2:
594     case BuiltinOperator_REDUCE_MAX:
595     case BuiltinOperator_REDUCE_MIN:
596     case BuiltinOperator_RELU6:
597       // In case of int16 inputs, the version is 3.
598       if (op_sig.input_types.at(0) == TensorType_INT16) {
599         return 3;
600       }
601       if (op_sig.input_types.at(0) == TensorType_INT8) {
602         return 2;
603       }
604       return 1;
605 
606     case BuiltinOperator_RNN:
607     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
608     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
609     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
610       if (op_sig.input_types.at(1) == TensorType_INT8 &&
611           op_sig.output_types.at(0) == TensorType_FLOAT32) {
612         if (op_sig.options.input_quantization.asymmetric_quantize_inputs) {
613           return 3;
614         } else {
615           return 2;
616         }
617       }
618       return 1;
619 
620     case BuiltinOperator_SPACE_TO_DEPTH:
621     case BuiltinOperator_SPLIT_V:
622     case BuiltinOperator_SUM:
623     case BuiltinOperator_LOG_SOFTMAX:
624     case BuiltinOperator_TOPK_V2:
625     case BuiltinOperator_ARG_MAX:
626     case BuiltinOperator_ARG_MIN:
627     case BuiltinOperator_GREATER:
628     case BuiltinOperator_GREATER_EQUAL:
629     case BuiltinOperator_LESS:
630     case BuiltinOperator_LESS_EQUAL:
631     case BuiltinOperator_SELECT:
632     case BuiltinOperator_RSQRT:
633     case BuiltinOperator_SQUARED_DIFFERENCE:
634     case BuiltinOperator_DEPTH_TO_SPACE:
635     case BuiltinOperator_MIRROR_PAD:
636       if (op_sig.input_types.at(0) == TensorType_INT8) {
637         return 2;
638       }
639       return 1;
640     // The version one of broadcast to op won't be not supported since the
641     // version one was rollbacked and the builtin op code number has been
642     // changed because of builtin op code shortage problem.
643     // Quantized broadcast_to is version 3
644     case BuiltinOperator_BROADCAST_TO:
645       if (op_sig.input_types.at(0) == TensorType_INT8 ||
646           op_sig.input_types.at(0) == TensorType_INT16) {
647         return 3;
648       }
649       return 2;
650     default:
651       return 1;
652   }
653 }
654 
GetTensorType(int32_t idx,const SubGraph * subgraph)655 TensorType GetTensorType(int32_t idx, const SubGraph* subgraph) {
656   if (idx == -1)
657     // For optional input/output, return none type directly.
658     return kTensorTypeNone;
659 
660   // Some tests have a graph with invalid tensor index.
661   TFLITE_DCHECK_GE(idx, 0);
662   if (subgraph->tensors() && idx < subgraph->tensors()->Length()) {
663     return subgraph->tensors()->Get(idx)->type();
664   }
665   LOG(ERROR) << "Can't access tenor " << idx;
666   return kTensorTypeNone;
667 }
668 
669 // Generate OpSignature with the given OperatorCode, Operator and Tensors (from
670 // SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and
671 // mostly input and output tensor types are enough to figure out op version.
672 // But some ops (DEPTHWISE_CONV_2D,  FULLY_CONNECTED, ...) require to pass their
673 // options to decide op version.
GetOpSignature(const OperatorCode * op_code,const Operator * op,const SubGraph * subgraph)674 OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
675                            const SubGraph* subgraph) {
676   auto builtin_code = GetBuiltinCode(op_code);
677   OpSignature op_sig = {builtin_code};
678 
679   switch (builtin_code) {
680     case BuiltinOperator_DEPTHWISE_CONV_2D: {
681       auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions();
682       if (conv_option) {
683         op_sig.options.depthwise_conv_2d.dilation_w_factor =
684             conv_option->dilation_w_factor();
685         op_sig.options.depthwise_conv_2d.dilation_h_factor =
686             conv_option->dilation_h_factor();
687       }
688       const Tensor* filter_tensor =
689           subgraph->tensors()->Get(op->inputs()->Get(1));
690       const QuantizationParameters* filter_quant =
691           filter_tensor->quantization();
692       int num_channels = filter_tensor->shape()->Get(3);
693       if (filter_quant && filter_quant->scale() &&
694           filter_quant->scale()->Length() &&
695           filter_quant->scale()->Length() == num_channels) {
696         op_sig.options.depthwise_conv_2d.is_per_channel_quantized = true;
697       }
698     } break;
699 
700     case BuiltinOperator_FAKE_QUANT: {
701       auto fakequant_option = op->builtin_options_as_FakeQuantOptions();
702       if (fakequant_option) {
703         op_sig.options.fakequant.narrow_range =
704             fakequant_option->narrow_range();
705       }
706     } break;
707 
708     case BuiltinOperator_FULLY_CONNECTED: {
709       auto fully_connected_option =
710           op->builtin_options_as_FullyConnectedOptions();
711       if (fully_connected_option) {
712         op_sig.options.fully_connected.keep_num_dims =
713             fully_connected_option->keep_num_dims();
714         op_sig.options.fully_connected.weights_format =
715             fully_connected_option->weights_format();
716         op_sig.options.fully_connected.asymmetric_quantize_inputs =
717             fully_connected_option->asymmetric_quantize_inputs();
718       }
719 
720       const Tensor* weight_tensor =
721           subgraph->tensors()->Get(op->inputs()->Get(1));
722       op_sig.options.fully_connected.sparse_weight =
723           (weight_tensor->sparsity() != nullptr);
724     } break;
725 
726     case BuiltinOperator_MUL: {
727       if (op->inputs()->Length() < 2 || op->outputs()->Length() < 1) {
728         break;
729       }
730       const Tensor* input1_tensor =
731           subgraph->tensors()->Get(op->inputs()->Get(0));
732       const Tensor* input2_tensor =
733           subgraph->tensors()->Get(op->inputs()->Get(1));
734       const Tensor* output_tensor =
735           subgraph->tensors()->Get(op->outputs()->Get(0));
736       const QuantizationParameters* input1_quant =
737           input1_tensor->quantization();
738       const QuantizationParameters* input2_qunt = input2_tensor->quantization();
739       const QuantizationParameters* output_quant =
740           output_tensor->quantization();
741       if (input1_quant && input1_quant->scale() &&
742           input1_quant->scale()->Length() && input2_qunt &&
743           input2_qunt->scale() && input2_qunt->scale()->Length() &&
744           output_quant && output_quant->scale() &&
745           output_quant->scale()->Length()) {
746         op_sig.options.mul.input1_scale = input1_quant->scale()->Get(0);
747         op_sig.options.mul.input2_scale = input2_qunt->scale()->Get(0);
748         op_sig.options.mul.output_scale = output_quant->scale()->Get(0);
749       }
750     } break;
751 
752     case BuiltinOperator_ADD: {
753       auto add_option = op->builtin_options_as_AddOptions();
754       op_sig.options.addsub.pot_scale_int16 = true;
755       if (add_option) {
756         op_sig.options.addsub.pot_scale_int16 = add_option->pot_scale_int16();
757       }
758     } break;
759 
760     case BuiltinOperator_SUB: {
761       auto sub_option = op->builtin_options_as_SubOptions();
762       op_sig.options.addsub.need_broadcast =
763           !HaveSameShapes(subgraph, op, 0, 1);
764       op_sig.options.addsub.num_dims =
765           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
766       op_sig.options.addsub.pot_scale_int16 = true;
767       if (sub_option) {
768         op_sig.options.addsub.pot_scale_int16 = sub_option->pot_scale_int16();
769       }
770     } break;
771 
772     case BuiltinOperator_LSTM: {
773       auto lstm_option = op->builtin_options_as_LSTMOptions();
774       if (lstm_option) {
775         op_sig.options.lstm.kernel_type = lstm_option->kernel_type();
776       }
777     } break;
778 
779     case BuiltinOperator_RESIZE_BILINEAR: {
780       auto resize_bilinear_option =
781           op->builtin_options_as_ResizeBilinearOptions();
782       if (resize_bilinear_option) {
783         op_sig.options.resize.half_pixel_centers =
784             resize_bilinear_option->half_pixel_centers();
785         op_sig.options.resize.align_corners =
786             resize_bilinear_option->align_corners();
787       }
788     } break;
789     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
790       auto resize_nn_option =
791           op->builtin_options_as_ResizeNearestNeighborOptions();
792       if (resize_nn_option) {
793         op_sig.options.resize.half_pixel_centers =
794             resize_nn_option->half_pixel_centers();
795         op_sig.options.resize.align_corners = resize_nn_option->align_corners();
796       }
797     } break;
798     case BuiltinOperator_CONV_2D: {
799       const Tensor* filter_tensor =
800           subgraph->tensors()->Get(op->inputs()->Get(1));
801       const QuantizationParameters* filter_quant =
802           filter_tensor->quantization();
803       int num_channels = filter_tensor->shape()->Get(0);
804       if (filter_quant && filter_quant->scale() &&
805           filter_quant->scale()->Length() &&
806           filter_quant->scale()->Length() == num_channels) {
807         op_sig.options.conv_2d.is_per_channel_quantized = true;
808       }
809     } break;
810     // TODO(b/150176627): Add tests for GetOpSignature.
811     case BuiltinOperator_STRIDED_SLICE:
812     case BuiltinOperator_SLICE:
813     case BuiltinOperator_SPACE_TO_BATCH_ND:
814     case BuiltinOperator_BATCH_TO_SPACE_ND:
815     case BuiltinOperator_TRANSPOSE: {
816       op_sig.options.single_input_op.num_dims = GetNumDims(subgraph, op, 0);
817     } break;
818 
819     case BuiltinOperator_DIV:
820     case BuiltinOperator_MAXIMUM:
821     case BuiltinOperator_MINIMUM: {
822       op_sig.options.broadcast.need_broadcast =
823           !HaveSameShapes(subgraph, op, 0, 1);
824       op_sig.options.broadcast.num_dims =
825           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
826     } break;
827 
828     case BuiltinOperator_BATCH_MATMUL: {
829       auto batch_matmul_option = op->builtin_options_as_BatchMatMulOptions();
830       op_sig.options.input_quantization.asymmetric_quantize_inputs =
831           batch_matmul_option->asymmetric_quantize_inputs();
832     } break;
833 
834     default:
835       break;
836   }
837 
838   for (int32_t i = 0; i < op->inputs()->Length(); ++i) {
839     TensorType tensor_type = GetTensorType(op->inputs()->Get(i), subgraph);
840     op_sig.input_types.push_back(tensor_type);
841   }
842   for (int32_t i = 0; i < op->outputs()->Length(); ++i) {
843     TensorType tensor_type = GetTensorType(op->outputs()->Get(i), subgraph);
844     op_sig.output_types.push_back(tensor_type);
845   }
846   return op_sig;
847 }
848 
UpdateOpVersion(uint8_t * model_buffer_pointer)849 void UpdateOpVersion(uint8_t* model_buffer_pointer) {
850   auto model = GetMutableModel(model_buffer_pointer);
851   auto subgraphs = model->subgraphs();
852 
853   for (int i = 0; i < subgraphs->Length(); ++i) {
854     const SubGraph* subgraph = subgraphs->Get(i);
855     for (int j = 0; j < subgraph->operators()->Length(); ++j) {
856       const Operator* op = subgraph->operators()->Get(j);
857       OperatorCode* op_code =
858           model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
859 
860       auto builtin_code = GetBuiltinCode(op_code);
861       if (builtin_code != BuiltinOperator_CUSTOM) {
862         OpSignature op_sig = GetOpSignature(op_code, op, subgraph);
863         // Update builtin operator version.
864         int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
865         if (!op_code->mutate_version(op_ver)) {
866           LOG(ERROR) << "Can't set operator "
867                      << EnumNameBuiltinOperator(builtin_code) << " to version "
868                      << op_ver;
869         }
870       }
871     }
872   }
873 }
874 
875 }  // namespace tflite
876