1 /**
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "run_tflite.h"
18
19 #include "tensorflow/contrib/lite/kernels/register.h"
20
21 #include <android/log.h>
22 #include <cstdio>
23 #include <sys/time.h>
24
25 #define LOG_TAG "NN_BENCHMARK"
26
BenchmarkModel(const char * modelfile)27 BenchmarkModel::BenchmarkModel(const char* modelfile) {
28 // Memory map the model. NOTE this needs lifetime greater than or equal
29 // to interpreter context.
30 mTfliteModel = tflite::FlatBufferModel::BuildFromFile(modelfile);
31 if (!mTfliteModel) {
32 __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
33 "Failed to load model %s", modelfile);
34 return;
35 }
36
37 tflite::ops::builtin::BuiltinOpResolver resolver;
38 tflite::InterpreterBuilder(*mTfliteModel, resolver)(&mTfliteInterpreter);
39 if (!mTfliteInterpreter) {
40 __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
41 "Failed to create TFlite interpreter");
42 return;
43 }
44 }
45
~BenchmarkModel()46 BenchmarkModel::~BenchmarkModel() {
47 }
48
setInput(const uint8_t * dataPtr,size_t length)49 bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
50 int input = mTfliteInterpreter->inputs()[0];
51 auto* input_tensor = mTfliteInterpreter->tensor(input);
52 switch (input_tensor->type) {
53 case kTfLiteFloat32:
54 case kTfLiteUInt8: {
55 void* raw = mTfliteInterpreter->typed_tensor<void>(input);
56 memcpy(raw, dataPtr, length);
57 break;
58 }
59 default:
60 __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
61 "Input tensor type not supported");
62 return false;
63 }
64 return true;
65 }
66
resizeInputTensors(std::vector<int> shape)67 bool BenchmarkModel::resizeInputTensors(std::vector<int> shape) {
68 // The benchmark only expects single input tensor, hardcoded as 0.
69 int input = mTfliteInterpreter->inputs()[0];
70 mTfliteInterpreter->ResizeInputTensor(input, shape);
71 if (mTfliteInterpreter->AllocateTensors() != kTfLiteOk) {
72 __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to allocate tensors!");
73 return false;
74 }
75 return true;
76 }
77
runBenchmark(int num_inferences,bool use_nnapi)78 bool BenchmarkModel::runBenchmark(int num_inferences,
79 bool use_nnapi) {
80 mTfliteInterpreter->UseNNAPI(use_nnapi);
81
82 for(int i = 0; i < num_inferences; i++){
83 auto status = mTfliteInterpreter->Invoke();
84 if (status != kTfLiteOk) {
85 __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to invoke: %d!", (int)status);
86 return false;
87 }
88 }
89 return true;
90 }