1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
16 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
17 
18 #include <cstdio>
19 #include <cstring>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/core/api/error_reporter.h"
23 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
24 #include "tensorflow/lite/kernels/internal/compatibility.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26 #include "tensorflow/lite/micro/compatibility.h"
27 #include "tensorflow/lite/micro/kernels/ethosu.h"
28 #include "tensorflow/lite/micro/kernels/fully_connected.h"
29 #include "tensorflow/lite/micro/kernels/micro_ops.h"
30 #include "tensorflow/lite/micro/micro_op_resolver.h"
31 #include "tensorflow/lite/schema/schema_generated.h"
32 
33 namespace tflite {
34 TfLiteRegistration* Register_DETECTION_POSTPROCESS();
35 
36 template <unsigned int tOpCount>
37 class MicroMutableOpResolver : public MicroOpResolver {
38  public:
39   TF_LITE_REMOVE_VIRTUAL_DELETE
40 
41   explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
error_reporter_(error_reporter)42       : error_reporter_(error_reporter) {}
43 
FindOp(tflite::BuiltinOperator op)44   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
45     if (op == BuiltinOperator_CUSTOM) return nullptr;
46 
47     for (unsigned int i = 0; i < registrations_len_; ++i) {
48       const TfLiteRegistration& registration = registrations_[i];
49       if (registration.builtin_code == op) {
50         return &registration;
51       }
52     }
53     return nullptr;
54   }
55 
FindOp(const char * op)56   const TfLiteRegistration* FindOp(const char* op) const override {
57     for (unsigned int i = 0; i < registrations_len_; ++i) {
58       const TfLiteRegistration& registration = registrations_[i];
59       if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
60           (strcmp(registration.custom_name, op) == 0)) {
61         return &registration;
62       }
63     }
64     return nullptr;
65   }
66 
GetOpDataParser(BuiltinOperator op)67   MicroOpResolver::BuiltinParseFunction GetOpDataParser(
68       BuiltinOperator op) const override {
69     TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
70     for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
71       if (builtin_codes_[i] == op) return builtin_parsers_[i];
72     }
73     return nullptr;
74   }
75 
76   // Registers a Custom Operator with the MicroOpResolver.
77   //
78   // Only the first call for a given name will be successful. i.e. if this
79   // function is called again for a previously added Custom Operator, the
80   // MicroOpResolver will be unchanged and this function will return
81   // kTfLiteError.
AddCustom(const char * name,TfLiteRegistration * registration)82   TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) {
83     if (registrations_len_ >= tOpCount) {
84       if (error_reporter_) {
85         TF_LITE_REPORT_ERROR(
86             error_reporter_,
87             "Couldn't register custom op '%s', resolver size is too small (%d)",
88             name, tOpCount);
89       }
90       return kTfLiteError;
91     }
92 
93     if (FindOp(name) != nullptr) {
94       if (error_reporter_ != nullptr) {
95         TF_LITE_REPORT_ERROR(error_reporter_,
96                              "Calling AddCustom for the same op more than once "
97                              "is not supported (Op: %s).",
98                              name);
99       }
100       return kTfLiteError;
101     }
102 
103     TfLiteRegistration* new_registration = &registrations_[registrations_len_];
104     registrations_len_ += 1;
105 
106     *new_registration = *registration;
107     new_registration->builtin_code = BuiltinOperator_CUSTOM;
108     new_registration->custom_name = name;
109     return kTfLiteOk;
110   }
111 
112   // The Add* functions below add the various Builtin operators to the
113   // MicroMutableOpResolver object.
114 
AddAbs()115   TfLiteStatus AddAbs() {
116     return AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS(),
117                       ParseAbs);
118   }
119 
AddAdd()120   TfLiteStatus AddAdd() {
121     return AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD(),
122                       ParseAdd);
123   }
124 
AddArgMax()125   TfLiteStatus AddArgMax() {
126     return AddBuiltin(BuiltinOperator_ARG_MAX,
127                       tflite::ops::micro::Register_ARG_MAX(), ParseArgMax);
128   }
129 
AddArgMin()130   TfLiteStatus AddArgMin() {
131     return AddBuiltin(BuiltinOperator_ARG_MIN,
132                       tflite::ops::micro::Register_ARG_MIN(), ParseArgMin);
133   }
134 
AddAveragePool2D()135   TfLiteStatus AddAveragePool2D() {
136     return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
137                       tflite::ops::micro::Register_AVERAGE_POOL_2D(),
138                       ParsePool);
139   }
140 
AddBatchToSpaceND()141   TfLiteStatus AddBatchToSpaceND() {
142     return AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND,
143                       Register_BATCH_TO_SPACE_ND(), ParseBatchToSpaceNd);
144   }
145 
AddCast()146   TfLiteStatus AddCast() {
147     return AddBuiltin(BuiltinOperator_CAST, Register_CAST(), ParseCast);
148   }
149 
AddCeil()150   TfLiteStatus AddCeil() {
151     return AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL(),
152                       ParseCeil);
153   }
154 
AddCircularBuffer()155   TfLiteStatus AddCircularBuffer() {
156     return AddCustom("CIRCULAR_BUFFER",
157                      tflite::ops::micro::Register_CIRCULAR_BUFFER());
158   }
159 
AddConcatenation()160   TfLiteStatus AddConcatenation() {
161     return AddBuiltin(BuiltinOperator_CONCATENATION,
162                       tflite::ops::micro::Register_CONCATENATION(),
163                       ParseConcatenation);
164   }
165 
AddConv2D()166   TfLiteStatus AddConv2D() {
167     return AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), ParseConv2D);
168   }
169 
AddCos()170   TfLiteStatus AddCos() {
171     return AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS(),
172                       ParseCos);
173   }
174 
AddDepthwiseConv2D()175   TfLiteStatus AddDepthwiseConv2D() {
176     return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
177                       Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D);
178   }
179 
AddDequantize()180   TfLiteStatus AddDequantize() {
181     return AddBuiltin(BuiltinOperator_DEQUANTIZE,
182                       tflite::ops::micro::Register_DEQUANTIZE(),
183                       ParseDequantize);
184   }
185 
AddDetectionPostprocess()186   TfLiteStatus AddDetectionPostprocess() {
187     return AddCustom("TFLite_Detection_PostProcess",
188                      tflite::Register_DETECTION_POSTPROCESS());
189   }
190 
AddEqual()191   TfLiteStatus AddEqual() {
192     return AddBuiltin(BuiltinOperator_EQUAL,
193                       tflite::ops::micro::Register_EQUAL(), ParseEqual);
194   }
195 
AddEthosU()196   TfLiteStatus AddEthosU() {
197     TfLiteRegistration* registration = tflite::Register_ETHOSU();
198     if (registration) {
199       return AddCustom(tflite::GetString_ETHOSU(), registration);
200     }
201     return kTfLiteOk;
202   }
203 
AddExp()204   TfLiteStatus AddExp() {
205     return AddBuiltin(BuiltinOperator_EXP, Register_EXP(), ParseExp);
206   }
207 
AddFloor()208   TfLiteStatus AddFloor() {
209     return AddBuiltin(BuiltinOperator_FLOOR,
210                       tflite::ops::micro::Register_FLOOR(), ParseFloor);
211   }
212 
213   TfLiteStatus AddFullyConnected(
214       const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) {
215     return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration,
216                       ParseFullyConnected);
217   }
218 
AddGreater()219   TfLiteStatus AddGreater() {
220     return AddBuiltin(BuiltinOperator_GREATER,
221                       tflite::ops::micro::Register_GREATER(), ParseGreater);
222   }
223 
AddGreaterEqual()224   TfLiteStatus AddGreaterEqual() {
225     return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
226                       tflite::ops::micro::Register_GREATER_EQUAL(),
227                       ParseGreaterEqual);
228   }
229 
AddHardSwish()230   TfLiteStatus AddHardSwish() {
231     return AddBuiltin(BuiltinOperator_HARD_SWISH,
232                       tflite::ops::micro::Register_HARD_SWISH(),
233                       ParseHardSwish);
234   }
235 
AddL2Normalization()236   TfLiteStatus AddL2Normalization() {
237     return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
238                       tflite::ops::micro::Register_L2_NORMALIZATION(),
239                       ParseL2Normalization);
240   }
241 
AddLess()242   TfLiteStatus AddLess() {
243     return AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS(),
244                       ParseLess);
245   }
246 
AddLessEqual()247   TfLiteStatus AddLessEqual() {
248     return AddBuiltin(BuiltinOperator_LESS_EQUAL,
249                       tflite::ops::micro::Register_LESS_EQUAL(),
250                       ParseLessEqual);
251   }
252 
AddLog()253   TfLiteStatus AddLog() {
254     return AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG(),
255                       ParseLog);
256   }
257 
AddLogicalAnd()258   TfLiteStatus AddLogicalAnd() {
259     return AddBuiltin(BuiltinOperator_LOGICAL_AND,
260                       tflite::ops::micro::Register_LOGICAL_AND(),
261                       ParseLogicalAnd);
262   }
263 
AddLogicalNot()264   TfLiteStatus AddLogicalNot() {
265     return AddBuiltin(BuiltinOperator_LOGICAL_NOT,
266                       tflite::ops::micro::Register_LOGICAL_NOT(),
267                       ParseLogicalNot);
268   }
269 
AddLogicalOr()270   TfLiteStatus AddLogicalOr() {
271     return AddBuiltin(BuiltinOperator_LOGICAL_OR,
272                       tflite::ops::micro::Register_LOGICAL_OR(),
273                       ParseLogicalOr);
274   }
275 
AddLogistic()276   TfLiteStatus AddLogistic() {
277     return AddBuiltin(BuiltinOperator_LOGISTIC,
278                       tflite::ops::micro::Register_LOGISTIC(), ParseLogistic);
279   }
280 
AddMaximum()281   TfLiteStatus AddMaximum() {
282     return AddBuiltin(BuiltinOperator_MAXIMUM,
283                       tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
284   }
285 
AddMaxPool2D()286   TfLiteStatus AddMaxPool2D() {
287     return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
288                       tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
289   }
290 
AddMean()291   TfLiteStatus AddMean() {
292     return AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN(),
293                       ParseReducer);
294   }
295 
AddMinimum()296   TfLiteStatus AddMinimum() {
297     return AddBuiltin(BuiltinOperator_MINIMUM,
298                       tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
299   }
300 
AddMul()301   TfLiteStatus AddMul() {
302     return AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL(),
303                       ParseMul);
304   }
305 
AddNeg()306   TfLiteStatus AddNeg() {
307     return AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG(),
308                       ParseNeg);
309   }
310 
AddNotEqual()311   TfLiteStatus AddNotEqual() {
312     return AddBuiltin(BuiltinOperator_NOT_EQUAL,
313                       tflite::ops::micro::Register_NOT_EQUAL(), ParseNotEqual);
314   }
315 
AddPack()316   TfLiteStatus AddPack() {
317     return AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK(),
318                       ParsePack);
319   }
320 
AddPad()321   TfLiteStatus AddPad() {
322     return AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD(),
323                       ParsePad);
324   }
325 
AddPadV2()326   TfLiteStatus AddPadV2() {
327     return AddBuiltin(BuiltinOperator_PADV2,
328                       tflite::ops::micro::Register_PADV2(), ParsePadV2);
329   }
330 
AddPrelu()331   TfLiteStatus AddPrelu() {
332     return AddBuiltin(BuiltinOperator_PRELU,
333                       tflite::ops::micro::Register_PRELU(), ParsePrelu);
334   }
335 
AddQuantize()336   TfLiteStatus AddQuantize() {
337     return AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
338                       ParseQuantize);
339   }
340 
AddReduceMax()341   TfLiteStatus AddReduceMax() {
342     return AddBuiltin(BuiltinOperator_REDUCE_MAX,
343                       tflite::ops::micro::Register_REDUCE_MAX(), ParseReducer);
344   }
345 
AddRelu()346   TfLiteStatus AddRelu() {
347     return AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU(),
348                       ParseRelu);
349   }
350 
AddRelu6()351   TfLiteStatus AddRelu6() {
352     return AddBuiltin(BuiltinOperator_RELU6,
353                       tflite::ops::micro::Register_RELU6(), ParseRelu6);
354   }
355 
AddReshape()356   TfLiteStatus AddReshape() {
357     return AddBuiltin(BuiltinOperator_RESHAPE,
358                       tflite::ops::micro::Register_RESHAPE(), ParseReshape);
359   }
360 
AddResizeNearestNeighbor()361   TfLiteStatus AddResizeNearestNeighbor() {
362     return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
363                       tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(),
364                       ParseResizeNearestNeighbor);
365   }
366 
AddRound()367   TfLiteStatus AddRound() {
368     return AddBuiltin(BuiltinOperator_ROUND,
369                       tflite::ops::micro::Register_ROUND(), ParseRound);
370   }
371 
AddRsqrt()372   TfLiteStatus AddRsqrt() {
373     return AddBuiltin(BuiltinOperator_RSQRT,
374                       tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
375   }
376 
AddShape()377   TfLiteStatus AddShape() {
378     return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape);
379   }
380 
AddSin()381   TfLiteStatus AddSin() {
382     return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(),
383                       ParseSin);
384   }
385 
AddSoftmax()386   TfLiteStatus AddSoftmax() {
387     return AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
388                       ParseSoftmax);
389   }
390 
AddSpaceToBatchNd()391   TfLiteStatus AddSpaceToBatchNd() {
392     return AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND,
393                       Register_SPACE_TO_BATCH_ND(), ParseSpaceToBatchNd);
394   }
395 
AddSplit()396   TfLiteStatus AddSplit() {
397     return AddBuiltin(BuiltinOperator_SPLIT,
398                       tflite::ops::micro::Register_SPLIT(), ParseSplit);
399   }
400 
AddSplitV()401   TfLiteStatus AddSplitV() {
402     return AddBuiltin(BuiltinOperator_SPLIT_V,
403                       tflite::ops::micro::Register_SPLIT_V(), ParseSplitV);
404   }
405 
AddSqrt()406   TfLiteStatus AddSqrt() {
407     return AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT(),
408                       ParseSqrt);
409   }
410 
AddSquare()411   TfLiteStatus AddSquare() {
412     return AddBuiltin(BuiltinOperator_SQUARE,
413                       tflite::ops::micro::Register_SQUARE(), ParseSquare);
414   }
415 
AddStridedSlice()416   TfLiteStatus AddStridedSlice() {
417     return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
418                       tflite::ops::micro::Register_STRIDED_SLICE(),
419                       ParseStridedSlice);
420   }
421 
AddSub()422   TfLiteStatus AddSub() {
423     return AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB(),
424                       ParseSub);
425   }
426 
AddSvdf()427   TfLiteStatus AddSvdf() {
428     return AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), ParseSvdf);
429   }
430 
AddTanh()431   TfLiteStatus AddTanh() {
432     return AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH(),
433                       ParseTanh);
434   }
435 
AddUnpack()436   TfLiteStatus AddUnpack() {
437     return AddBuiltin(BuiltinOperator_UNPACK,
438                       tflite::ops::micro::Register_UNPACK(), ParseUnpack);
439   }
440 
AddZerosLike()441   TfLiteStatus AddZerosLike() {
442     return AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE(),
443                       ParseZerosLike);
444   }
445 
GetRegistrationLength()446   unsigned int GetRegistrationLength() { return registrations_len_; }
447 
448  private:
AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration & registration,MicroOpResolver::BuiltinParseFunction parser)449   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
450                           const TfLiteRegistration& registration,
451                           MicroOpResolver::BuiltinParseFunction parser) {
452     if (op == BuiltinOperator_CUSTOM) {
453       if (error_reporter_ != nullptr) {
454         TF_LITE_REPORT_ERROR(error_reporter_,
455                              "Invalid parameter BuiltinOperator_CUSTOM to the "
456                              "AddBuiltin function.");
457       }
458       return kTfLiteError;
459     }
460 
461     if (FindOp(op) != nullptr) {
462       if (error_reporter_ != nullptr) {
463         TF_LITE_REPORT_ERROR(error_reporter_,
464                              "Calling AddBuiltin with the same op more than "
465                              "once is not supported (Op: #%d).",
466                              op);
467       }
468       return kTfLiteError;
469     }
470 
471     if (registrations_len_ >= tOpCount) {
472       if (error_reporter_) {
473         TF_LITE_REPORT_ERROR(error_reporter_,
474                              "Couldn't register builtin op #%d, resolver size "
475                              "is too small (%d).",
476                              op, tOpCount);
477       }
478       return kTfLiteError;
479     }
480 
481     registrations_[registrations_len_] = registration;
482     // Strictly speaking, the builtin_code is not necessary for TFLM but filling
483     // it in regardless.
484     registrations_[registrations_len_].builtin_code = op;
485     registrations_len_++;
486 
487     builtin_codes_[num_buitin_ops_] = op;
488     builtin_parsers_[num_buitin_ops_] = parser;
489     num_buitin_ops_++;
490 
491     return kTfLiteOk;
492   }
493 
494   TfLiteRegistration registrations_[tOpCount];
495   unsigned int registrations_len_ = 0;
496 
497   // Arrays (and counter) to store the builtin codes and their corresponding
498   // parse functions as these are registered with the Op Resolver.
499   BuiltinOperator builtin_codes_[tOpCount];
500   MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
501   unsigned int num_buitin_ops_ = 0;
502 
503   ErrorReporter* error_reporter_;
504 };
505 
506 };  // namespace tflite
507 
508 #endif  // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
509