1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "model.h"
18 
19 #include "sine_model_data.h"
20 #include "tensorflow/lite/micro/kernels/micro_ops.h"
21 #include "tensorflow/lite/micro/micro_error_reporter.h"
22 #include "tensorflow/lite/micro/micro_interpreter.h"
23 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 //  The following registration code is generated. Check the following commit for
27 //  details.
28 //  https://github.com/tensorflow/tensorflow/commit/098556c3a96e1d51f79606c0834547cb2aa20908
29 
30 namespace {
RegisterSelectedOps(::tflite::MicroMutableOpResolver * resolver)31 void RegisterSelectedOps(::tflite::MicroMutableOpResolver *resolver) {
32   resolver->AddBuiltin(
33       ::tflite::BuiltinOperator_FULLY_CONNECTED,
34       // For now the op version is not supported in the generated code, so this
35       // version still needs to added manually.
36       ::tflite::ops::micro::Register_FULLY_CONNECTED(), 1, 4);
37 }
38 }  // namespace
39 
40 namespace demo {
run(float x_val)41 float run(float x_val) {
42   tflite::MicroErrorReporter micro_error_reporter;
43   const tflite::Model *model = tflite::GetModel(g_sine_model_data);
44   // TODO(wangtz): Check for schema version.
45 
46   tflite::MicroMutableOpResolver resolver;
47   RegisterSelectedOps(&resolver);
48   constexpr int kTensorAreanaSize = 2 * 1024;
49   uint8_t tensor_arena[kTensorAreanaSize];
50 
51   tflite::MicroInterpreter interpreter(
52       model, resolver, tensor_arena, kTensorAreanaSize, &micro_error_reporter);
53   interpreter.AllocateTensors();
54 
55   TfLiteTensor *input = interpreter.input(0);
56   TfLiteTensor *output = interpreter.output(0);
57   input->data.f[0] = x_val;
58   TfLiteStatus invoke_status = interpreter.Invoke();
59   if (invoke_status != kTfLiteOk) {
60     micro_error_reporter.ReportError(nullptr, "Internal error: invoke failed.");
61     return 0.0;
62   }
63   float y_val = output->data.f[0];
64   return y_val;
65 }
66 
67 }  // namespace demo
68