1 /**
2  * Copyright 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "run_tflite.h"
18 
19 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
20 
21 #include <jni.h>
22 #include <string>
23 #include <iomanip>
24 #include <sstream>
25 #include <fcntl.h>
26 
27 #include <android/asset_manager_jni.h>
28 #include <android/log.h>
29 #include <android/sharedmem.h>
30 #include <sys/mman.h>
31 
32 
33 extern "C"
34 JNIEXPORT jlong
35 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName)36 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
37         JNIEnv *env,
38         jobject /* this */,
39         jstring _modelFileName,
40         jboolean _useNnApi,
41         jboolean _enableIntermediateTensorsDump,
42         jstring _nnApiDeviceName) {
43     const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
44     const char *nnApiDeviceName =
45         _nnApiDeviceName == NULL
46             ? NULL
47             : env->GetStringUTFChars(_nnApiDeviceName, NULL);
48     void *handle =
49         BenchmarkModel::create(modelFileName, _useNnApi,
50                                _enableIntermediateTensorsDump, nnApiDeviceName);
51     env->ReleaseStringUTFChars(_modelFileName, modelFileName);
52     if (_nnApiDeviceName != NULL) {
53         env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
54     }
55 
56     return (jlong)(uintptr_t)handle;
57 }
58 
59 extern "C"
60 JNIEXPORT void
61 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)62 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
63         JNIEnv *env,
64         jobject /* this */,
65         jlong _modelHandle) {
66     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
67     delete(model);
68 }
69 
70 extern "C"
71 JNIEXPORT jboolean
72 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)73 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
74         JNIEnv *env,
75         jobject /* this */,
76         jlong _modelHandle,
77         jintArray _inputShape) {
78     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
79     jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
80     jsize shapeLen = env->GetArrayLength(_inputShape);
81 
82     std::vector<int> shape(shapePtr, shapePtr + shapeLen);
83     return model->resizeInputTensors(std::move(shape));
84 }
85 
86 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
87 class InferenceInOutSequenceList {
88 public:
89     InferenceInOutSequenceList(JNIEnv *env,
90                                const jobject& inOutDataList,
91                                bool expectGoldenOutputs);
92     ~InferenceInOutSequenceList();
93 
isValid() const94     bool isValid() const { return mValid; }
95 
data() const96     const std::vector<InferenceInOutSequence>& data() const { return mData; }
97 
98 private:
99     JNIEnv *mEnv;  // not owned.
100 
101     std::vector<InferenceInOutSequence> mData;
102     std::vector<jbyteArray> mInputArrays;
103     std::vector<jobjectArray> mOutputArrays;
104     bool mValid;
105 };
106 
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)107 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
108                                                        const jobject& inOutDataList,
109                                                        bool expectGoldenOutputs)
110     : mEnv(env), mValid(false) {
111 
112     jclass list_class = env->FindClass("java/util/List");
113     if (list_class == nullptr) { return; }
114     jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
115     if (list_size == nullptr) { return; }
116     jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
117     if (list_get == nullptr) { return; }
118     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
119     if (list_add == nullptr) { return; }
120 
121     jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
122     if (inOutSeq_class == nullptr) { return; }
123     jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
124     if (inOutSeq_size == nullptr) { return; }
125     jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
126                                               "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
127     if (inOutSeq_get == nullptr) { return; }
128 
129     jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
130     if (inout_class == nullptr) { return; }
131     jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
132     if (inout_input == nullptr) { return; }
133     jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
134     if (inout_expectedOutputs == nullptr) { return; }
135     jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
136             "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
137     if (inout_inputCreator == nullptr) { return; }
138 
139 
140 
141     // Fetch input/output arrays
142     size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
143     if (env->ExceptionCheck()) { return; }
144     mData.reserve(data_count);
145 
146     jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
147     if (inputCreator_class == nullptr) { return; }
148     jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
149     if (createInput_method == nullptr) { return; }
150 
151     for (int seq_index = 0; seq_index < data_count; ++seq_index) {
152         jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
153         if (mEnv->ExceptionCheck()) { return; }
154 
155         size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
156         if (mEnv->ExceptionCheck()) { return; }
157 
158         mData.push_back(InferenceInOutSequence{});
159         auto& seq = mData.back();
160         seq.reserve(seqLen);
161         for (int i = 0; i < seqLen; ++i) {
162             jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
163             if (mEnv->ExceptionCheck()) { return; }
164 
165             uint8_t* input_data = nullptr;
166             size_t input_len = 0;
167             std::function<bool(uint8_t*, size_t)> inputCreator;
168             jbyteArray input = static_cast<jbyteArray>(
169                     mEnv->GetObjectField(inout, inout_input));
170             mInputArrays.push_back(input);
171             if (input != nullptr) {
172                 input_data = reinterpret_cast<uint8_t*>(
173                         mEnv->GetByteArrayElements(input, NULL));
174                 input_len = mEnv->GetArrayLength(input);
175             } else {
176                 inputCreator = [env, inout, inout_inputCreator, createInput_method](
177                         uint8_t* buffer, size_t length) {
178                     jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
179                     if (byteBuffer == nullptr) { return false; }
180                     jobject creator = env->GetObjectField(inout, inout_inputCreator);
181                     if (creator == nullptr) { return false; }
182                     env->CallVoidMethod(creator, createInput_method, byteBuffer);
183                     env->DeleteLocalRef(byteBuffer);
184                     if (env->ExceptionCheck()) { return false; }
185                     return true;
186                 };
187             }
188 
189             jobjectArray expectedOutputs = static_cast<jobjectArray>(
190                     mEnv->GetObjectField(inout, inout_expectedOutputs));
191             mOutputArrays.push_back(expectedOutputs);
192             seq.push_back({input_data, input_len, {}, inputCreator});
193 
194             // Add expected output to sequence added above
195             if (expectedOutputs != nullptr) {
196                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
197                 auto& outputs = seq.back().outputs;
198                 outputs.reserve(expectedOutputsLength);
199 
200                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
201                     jbyteArray expectedOutput =
202                             static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
203                     if (env->ExceptionCheck()) {
204                         return;
205                     }
206                     if (expectedOutput == nullptr) {
207                         jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
208                         mEnv->ThrowNew(iaeClass, "Null expected output array");
209                         return;
210                     }
211 
212                     uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
213                                         mEnv->GetByteArrayElements(expectedOutput, NULL));
214                     size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
215                     outputs.push_back({ expectedOutput_data, expectedOutput_len});
216                 }
217             } else {
218                 if (expectGoldenOutputs) {
219                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
220                     mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
221                     return;
222                 }
223             }
224         }
225     }
226     mValid = true;
227 }
228 
~InferenceInOutSequenceList()229 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
230     // Note that we may land here with a pending JNI exception so cannot call
231     // java objects.
232     int arrayIndex = 0;
233     for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
234         for (int i = 0; i < mData[seq_index].size(); ++i) {
235             jbyteArray input = mInputArrays[arrayIndex];
236             if (input != nullptr) {
237                 mEnv->ReleaseByteArrayElements(
238                         input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
239             }
240             jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
241             if (expectedOutputs != nullptr) {
242                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
243                 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
244                     // Should not happen? :)
245                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
246                     mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
247                                    "and internal array of its bufers");
248                     return;
249                 }
250 
251                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
252                     jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
253                     mEnv->ReleaseByteArrayElements(
254                         expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
255                         JNI_ABORT);
256                 }
257             }
258             arrayIndex++;
259         }
260     }
261 }
262 
263 extern "C"
264 JNIEXPORT jboolean
265 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)266 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
267         JNIEnv *env,
268         jobject /* this */,
269         jlong _modelHandle,
270         jobject inOutDataList,
271         jobject resultList,
272         jint inferencesSeqMaxCount,
273         jfloat timeoutSec,
274         jint flags) {
275 
276     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
277 
278     jclass list_class = env->FindClass("java/util/List");
279     if (list_class == nullptr) { return false; }
280     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
281     if (list_add == nullptr) { return false; }
282 
283     jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
284     if (result_class == nullptr) { return false; }
285     jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
286     if (result_ctor == nullptr) { return false; }
287 
288     std::vector<InferenceResult> result;
289 
290     const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
291     InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
292     if (!data.isValid()) {
293         return false;
294     }
295 
296     // TODO: Remove success boolean from this method and throw an exception in case of problems
297     bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
298 
299     // Generate results
300     if (success) {
301         for (const InferenceResult &rentry : result) {
302             jobjectArray inferenceOutputs = nullptr;
303             jfloatArray meanSquareErrorArray = nullptr;
304             jfloatArray maxSingleErrorArray = nullptr;
305 
306             if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
307                 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
308                 if (env->ExceptionCheck()) { return false; }
309                 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
310                 if (env->ExceptionCheck()) { return false; }
311                 {
312                     jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
313                     memcpy(bytes,
314                            &rentry.meanSquareErrors[0],
315                            rentry.meanSquareErrors.size() * sizeof(float));
316                     env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
317                 }
318                 {
319                     jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
320                     memcpy(bytes,
321                            &rentry.maxSingleErrors[0],
322                            rentry.maxSingleErrors.size() * sizeof(float));
323                     env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
324                 }
325             }
326 
327             if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
328                 jclass byteArrayClass = env->FindClass("[B");
329 
330                 inferenceOutputs = env->NewObjectArray(
331                     rentry.inferenceOutputs.size(),
332                     byteArrayClass, nullptr);
333 
334                 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
335                     jbyteArray inferenceOutput = nullptr;
336                     inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
337                     if (env->ExceptionCheck()) { return false; }
338                     jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
339                     memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
340                     env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
341                     env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
342                 }
343             }
344 
345             jobject object = env->NewObject(
346                 result_class, result_ctor, rentry.computeTimeSec,
347                 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
348                 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
349             if (env->ExceptionCheck() || object == NULL) { return false; }
350 
351             env->CallBooleanMethod(resultList, list_add, object);
352             if (env->ExceptionCheck()) { return false; }
353         }
354     }
355 
356     return success;
357 }
358 
359 extern "C"
360 JNIEXPORT void
361 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)362 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
363         JNIEnv *env,
364         jobject /* this */,
365         jlong _modelHandle,
366         jstring dumpPath,
367         jobject inOutDataList) {
368 
369     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
370 
371     InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
372     if (!data.isValid()) {
373         return;
374     }
375 
376     const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
377     model->dumpAllLayers(dumpPathStr, data.data());
378     env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
379 }
380 
381 extern "C"
382 JNIEXPORT jboolean
383 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator()384 Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator() {
385   uint32_t device_count = 0;
386   NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
387   // We only consider a real device, not 'nnapi-reference'.
388   return device_count > 1;
389 }
390