/**
 * Copyright 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "run_tflite.h"

#include "tensorflow/contrib/lite/kernels/register.h"

#include <android/log.h>
#include <cstdio>
#include <cstring>
#include <sys/time.h>

#define LOG_TAG "NN_BENCHMARK"

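// Loads the model from |modelfile| and builds a TFLite interpreter around it
// using the builtin op resolver. On failure, mTfliteInterpreter is left null.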
BenchmarkModel::BenchmarkModel(const char* modelfile) {
    // Memory-map the model. NOTE: the mapped model must stay alive at least
    // as long as the interpreter built from it.
    mTfliteModel = tflite::FlatBufferModel::BuildFromFile(modelfile);
    if (!mTfliteModel) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                            "Failed to load model %s", modelfile);
        return;
    }

    tflite::ops::builtin::BuiltinOpResolver resolver;
    tflite::InterpreterBuilder(*mTfliteModel, resolver)(&mTfliteInterpreter);
    if (!mTfliteInterpreter) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                            "Failed to create TFlite interpreter");
        return;
    }
}

BenchmarkModel::~BenchmarkModel() {
}

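// Copies |length| bytes of raw input data into the interpreter's first input
// tensor. Fails if the buffer size does not match the tensor size or if the
// tensor type is unsupported.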
bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
    int input = mTfliteInterpreter->inputs()[0];
    auto* input_tensor = mTfliteInterpreter->tensor(input);
    if (input_tensor->bytes != length) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                            "Input size %zu does not match tensor size %zu",
                            length, input_tensor->bytes);
        return false;
    }
    switch (input_tensor->type) {
        case kTfLiteFloat32:
        case kTfLiteUInt8: {
            void* raw = mTfliteInterpreter->typed_tensor<void>(input);
            memcpy(raw, dataPtr, length);
            break;
        }
        default:
            __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                                "Input tensor type not supported");
            return false;
    }
    return true;
}

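// Resizes the single input tensor to |shape| and reallocates all tensor
// buffers accordingly.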
bool BenchmarkModel::resizeInputTensors(std::vector<int> shape) {
    // The benchmark expects only a single input tensor, hardcoded as index 0.
    int input = mTfliteInterpreter->inputs()[0];
    if (mTfliteInterpreter->ResizeInputTensor(input, shape) != kTfLiteOk) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to resize input tensor!");
        return false;
    }
    if (mTfliteInterpreter->AllocateTensors() != kTfLiteOk) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to allocate tensors!");
        return false;
    }
    return true;
}

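// Runs inference |num_inferences| times, optionally delegating to NNAPI.
// Returns false on the first failed Invoke() call.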
bool BenchmarkModel::runBenchmark(int num_inferences,
                                  bool use_nnapi) {
    mTfliteInterpreter->UseNNAPI(use_nnapi);

    for (int i = 0; i < num_inferences; i++) {
        auto status = mTfliteInterpreter->Invoke();
        if (status != kTfLiteOk) {
            __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                                "Failed to invoke: %d!", (int)status);
            return false;
        }
    }
    return true;
}
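
// Illustrative usage sketch, kept as a comment since it is not part of this
// translation unit. The model path, input shape, and float32 input type below
// are placeholder assumptions, not values taken from the benchmark harness:
//
//   BenchmarkModel model("/data/local/tmp/model.tflite");
//   model.resizeInputTensors({1, 224, 224, 3});
//   std::vector<uint8_t> input(1 * 224 * 224 * 3 * sizeof(float));
//   model.setInput(input.data(), input.size());
//   model.runBenchmark(/*num_inferences=*/100, /*use_nnapi=*/true);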