1 /**
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "run_tflite.h"
18
19 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
20
21 #include <jni.h>
22 #include <string>
23 #include <iomanip>
24 #include <sstream>
25 #include <fcntl.h>
26
27 #include <android/asset_manager_jni.h>
28 #include <android/log.h>
29 #include <android/sharedmem.h>
30 #include <sys/mman.h>
31
32
33 extern "C"
34 JNIEXPORT jlong
35 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName)36 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
37 JNIEnv *env,
38 jobject /* this */,
39 jstring _modelFileName,
40 jboolean _useNnApi,
41 jboolean _enableIntermediateTensorsDump,
42 jstring _nnApiDeviceName) {
43 const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
44 const char *nnApiDeviceName =
45 _nnApiDeviceName == NULL
46 ? NULL
47 : env->GetStringUTFChars(_nnApiDeviceName, NULL);
48 void *handle =
49 BenchmarkModel::create(modelFileName, _useNnApi,
50 _enableIntermediateTensorsDump, nnApiDeviceName);
51 env->ReleaseStringUTFChars(_modelFileName, modelFileName);
52 if (_nnApiDeviceName != NULL) {
53 env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
54 }
55
56 return (jlong)(uintptr_t)handle;
57 }
58
59 extern "C"
60 JNIEXPORT void
61 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)62 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
63 JNIEnv *env,
64 jobject /* this */,
65 jlong _modelHandle) {
66 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
67 delete(model);
68 }
69
70 extern "C"
71 JNIEXPORT jboolean
72 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)73 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
74 JNIEnv *env,
75 jobject /* this */,
76 jlong _modelHandle,
77 jintArray _inputShape) {
78 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
79 jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
80 jsize shapeLen = env->GetArrayLength(_inputShape);
81
82 std::vector<int> shape(shapePtr, shapePtr + shapeLen);
83 return model->resizeInputTensors(std::move(shape));
84 }
85
86 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
87 class InferenceInOutSequenceList {
88 public:
89 InferenceInOutSequenceList(JNIEnv *env,
90 const jobject& inOutDataList,
91 bool expectGoldenOutputs);
92 ~InferenceInOutSequenceList();
93
isValid() const94 bool isValid() const { return mValid; }
95
data() const96 const std::vector<InferenceInOutSequence>& data() const { return mData; }
97
98 private:
99 JNIEnv *mEnv; // not owned.
100
101 std::vector<InferenceInOutSequence> mData;
102 std::vector<jbyteArray> mInputArrays;
103 std::vector<jobjectArray> mOutputArrays;
104 bool mValid;
105 };
106
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)107 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
108 const jobject& inOutDataList,
109 bool expectGoldenOutputs)
110 : mEnv(env), mValid(false) {
111
112 jclass list_class = env->FindClass("java/util/List");
113 if (list_class == nullptr) { return; }
114 jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
115 if (list_size == nullptr) { return; }
116 jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
117 if (list_get == nullptr) { return; }
118 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
119 if (list_add == nullptr) { return; }
120
121 jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
122 if (inOutSeq_class == nullptr) { return; }
123 jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
124 if (inOutSeq_size == nullptr) { return; }
125 jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
126 "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
127 if (inOutSeq_get == nullptr) { return; }
128
129 jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
130 if (inout_class == nullptr) { return; }
131 jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
132 if (inout_input == nullptr) { return; }
133 jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
134 if (inout_expectedOutputs == nullptr) { return; }
135 jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
136 "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
137 if (inout_inputCreator == nullptr) { return; }
138
139
140
141 // Fetch input/output arrays
142 size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
143 if (env->ExceptionCheck()) { return; }
144 mData.reserve(data_count);
145
146 jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
147 if (inputCreator_class == nullptr) { return; }
148 jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
149 if (createInput_method == nullptr) { return; }
150
151 for (int seq_index = 0; seq_index < data_count; ++seq_index) {
152 jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
153 if (mEnv->ExceptionCheck()) { return; }
154
155 size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
156 if (mEnv->ExceptionCheck()) { return; }
157
158 mData.push_back(InferenceInOutSequence{});
159 auto& seq = mData.back();
160 seq.reserve(seqLen);
161 for (int i = 0; i < seqLen; ++i) {
162 jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
163 if (mEnv->ExceptionCheck()) { return; }
164
165 uint8_t* input_data = nullptr;
166 size_t input_len = 0;
167 std::function<bool(uint8_t*, size_t)> inputCreator;
168 jbyteArray input = static_cast<jbyteArray>(
169 mEnv->GetObjectField(inout, inout_input));
170 mInputArrays.push_back(input);
171 if (input != nullptr) {
172 input_data = reinterpret_cast<uint8_t*>(
173 mEnv->GetByteArrayElements(input, NULL));
174 input_len = mEnv->GetArrayLength(input);
175 } else {
176 inputCreator = [env, inout, inout_inputCreator, createInput_method](
177 uint8_t* buffer, size_t length) {
178 jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
179 if (byteBuffer == nullptr) { return false; }
180 jobject creator = env->GetObjectField(inout, inout_inputCreator);
181 if (creator == nullptr) { return false; }
182 env->CallVoidMethod(creator, createInput_method, byteBuffer);
183 env->DeleteLocalRef(byteBuffer);
184 if (env->ExceptionCheck()) { return false; }
185 return true;
186 };
187 }
188
189 jobjectArray expectedOutputs = static_cast<jobjectArray>(
190 mEnv->GetObjectField(inout, inout_expectedOutputs));
191 mOutputArrays.push_back(expectedOutputs);
192 seq.push_back({input_data, input_len, {}, inputCreator});
193
194 // Add expected output to sequence added above
195 if (expectedOutputs != nullptr) {
196 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
197 auto& outputs = seq.back().outputs;
198 outputs.reserve(expectedOutputsLength);
199
200 for (jsize j = 0;j < expectedOutputsLength; ++j) {
201 jbyteArray expectedOutput =
202 static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
203 if (env->ExceptionCheck()) {
204 return;
205 }
206 if (expectedOutput == nullptr) {
207 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
208 mEnv->ThrowNew(iaeClass, "Null expected output array");
209 return;
210 }
211
212 uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
213 mEnv->GetByteArrayElements(expectedOutput, NULL));
214 size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
215 outputs.push_back({ expectedOutput_data, expectedOutput_len});
216 }
217 } else {
218 if (expectGoldenOutputs) {
219 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
220 mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
221 return;
222 }
223 }
224 }
225 }
226 mValid = true;
227 }
228
~InferenceInOutSequenceList()229 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
230 // Note that we may land here with a pending JNI exception so cannot call
231 // java objects.
232 int arrayIndex = 0;
233 for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
234 for (int i = 0; i < mData[seq_index].size(); ++i) {
235 jbyteArray input = mInputArrays[arrayIndex];
236 if (input != nullptr) {
237 mEnv->ReleaseByteArrayElements(
238 input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
239 }
240 jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
241 if (expectedOutputs != nullptr) {
242 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
243 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
244 // Should not happen? :)
245 jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
246 mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
247 "and internal array of its bufers");
248 return;
249 }
250
251 for (jsize j = 0;j < expectedOutputsLength; ++j) {
252 jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
253 mEnv->ReleaseByteArrayElements(
254 expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
255 JNI_ABORT);
256 }
257 }
258 arrayIndex++;
259 }
260 }
261 }
262
263 extern "C"
264 JNIEXPORT jboolean
265 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)266 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
267 JNIEnv *env,
268 jobject /* this */,
269 jlong _modelHandle,
270 jobject inOutDataList,
271 jobject resultList,
272 jint inferencesSeqMaxCount,
273 jfloat timeoutSec,
274 jint flags) {
275
276 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
277
278 jclass list_class = env->FindClass("java/util/List");
279 if (list_class == nullptr) { return false; }
280 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
281 if (list_add == nullptr) { return false; }
282
283 jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
284 if (result_class == nullptr) { return false; }
285 jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
286 if (result_ctor == nullptr) { return false; }
287
288 std::vector<InferenceResult> result;
289
290 const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
291 InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
292 if (!data.isValid()) {
293 return false;
294 }
295
296 // TODO: Remove success boolean from this method and throw an exception in case of problems
297 bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
298
299 // Generate results
300 if (success) {
301 for (const InferenceResult &rentry : result) {
302 jobjectArray inferenceOutputs = nullptr;
303 jfloatArray meanSquareErrorArray = nullptr;
304 jfloatArray maxSingleErrorArray = nullptr;
305
306 if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
307 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
308 if (env->ExceptionCheck()) { return false; }
309 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
310 if (env->ExceptionCheck()) { return false; }
311 {
312 jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
313 memcpy(bytes,
314 &rentry.meanSquareErrors[0],
315 rentry.meanSquareErrors.size() * sizeof(float));
316 env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
317 }
318 {
319 jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
320 memcpy(bytes,
321 &rentry.maxSingleErrors[0],
322 rentry.maxSingleErrors.size() * sizeof(float));
323 env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
324 }
325 }
326
327 if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
328 jclass byteArrayClass = env->FindClass("[B");
329
330 inferenceOutputs = env->NewObjectArray(
331 rentry.inferenceOutputs.size(),
332 byteArrayClass, nullptr);
333
334 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
335 jbyteArray inferenceOutput = nullptr;
336 inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
337 if (env->ExceptionCheck()) { return false; }
338 jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
339 memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
340 env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
341 env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
342 }
343 }
344
345 jobject object = env->NewObject(
346 result_class, result_ctor, rentry.computeTimeSec,
347 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
348 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
349 if (env->ExceptionCheck() || object == NULL) { return false; }
350
351 env->CallBooleanMethod(resultList, list_add, object);
352 if (env->ExceptionCheck()) { return false; }
353 }
354 }
355
356 return success;
357 }
358
359 extern "C"
360 JNIEXPORT void
361 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)362 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
363 JNIEnv *env,
364 jobject /* this */,
365 jlong _modelHandle,
366 jstring dumpPath,
367 jobject inOutDataList) {
368
369 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
370
371 InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
372 if (!data.isValid()) {
373 return;
374 }
375
376 const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
377 model->dumpAllLayers(dumpPathStr, data.data());
378 env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
379 }
380
381 extern "C"
382 JNIEXPORT jboolean
383 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator()384 Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator() {
385 uint32_t device_count = 0;
386 NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
387 // We only consider a real device, not 'nnapi-reference'.
388 return device_count > 1;
389 }
390