/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;

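/**
 * Base class for benchmarking TFLite models through the 'nnbenchmark_jni' native library,
 * optionally delegated to NNAPI accelerators.
 *
 * <p>Typical lifecycle (an illustrative sketch only; the constructor arguments shown are
 * hypothetical placeholders):
 * <pre>{@code
 * try (NNTestBase test = new NNTestBase(modelName, modelFile, inputShape,
 *         inputOutputAssets, null, evaluatorConfig, minSdkVersion)) {
 *     test.useNNApi();
 *     test.checkSdkVersion();
 *     if (test.setupModel(context)) {
 *         test.runBenchmark(60.0f);
 *     }
 * }
 * }</pre>
 */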
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'nnbenchmark_jni' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator?
    // We only consider real devices, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the names of the available NNAPI accelerators.
     *
     * @return false if any error occurred, true otherwise
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

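    /**
     * Returns the names of the available NNAPI accelerators, excluding 'nnapi-reference'.
     *
     * <p>Illustrative use (a sketch; {@code test} is a hypothetical NNTestBase instance):
     * <pre>{@code
     * for (String accelerator : NNTestBase.availableAcceleratorNames()) {
     *     test.useNNApi();
     *     test.setNNApiDeviceName(accelerator);
     * }
     * }</pre>
     */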
    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!");
            return Collections.emptyList();
        }
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * <p>Useful in cases where there are no straightforward golden output values for the
     * benchmark. This also skips calculating the basic (golden-output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    /** Collect only one benchmark result out of every ten. */
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets nor inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets and inputOutputDatasets given. Only one is"
                            + " supported at a time.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

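    /**
     * Copies the model asset to a temporary file and initializes it through the native library.
     *
     * @return false if the model could not be initialized or its input tensors could not be
     *         resized to the configured input shape
     */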
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause the device to run out of space.",
                    path));
        }
    }

    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else {
                if (lastGolden != mHasGoldenOutputs) {
                    throw new IllegalArgumentException(
                            "Some inputs for " + mModelName + " have outputs while some don't.");
                }
            }
        }
        return inOutList;
    }

    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

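    /**
     * Computes the default benchmark flags for this test.
     *
     * <p>Worked example (hypothetical configuration): with no golden outputs, no evaluator, and
     * result sampling disabled, this returns
     * {@code FLAG_IGNORE_GOLDEN_OUTPUT | FLAG_DISCARD_INFERENCE_OUTPUT}, i.e. {@code 3}.
     */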
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests we will collect only a sample of the results
        if (mSampleResults) {
            flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }

    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many inferences as possible before the timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Runs through the whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        // Integer ceil: e.g. minInferences = 600 over a set of 250 inferences gives setRepeat = 3.
        int setRepeat = (minInferences + setInferences - 1) / setInferences;
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached the timeout or failed to evaluate the whole set for some other reason;
            // abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set within %f seconds: expected %d results,"
                            + " received %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

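    /**
     * Runs the native compilation benchmark with the given warmup and run timeouts and an
     * iteration cap.
     *
     * @throws BenchmarkException if the native benchmark fails to produce a result
     */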
    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec);
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy the model to the cache dir, so that TFLite can load it directly.
    private String copyAssetToFile() throws IOException {
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

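    /**
     * Copies a model asset to {@code targetPath} so it can be opened from the file system.
     *
     * <p>Note (an aside, not original behavior): on API level 26+ the manual buffer loop below
     * could be replaced with
     * {@code java.nio.file.Files.copy(in, outFile.toPath(), StandardCopyOption.REPLACE_EXISTING)};
     * the explicit loop keeps compatibility with older SDK levels.
     */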
    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                byte[] byteBuffer = new byte[1024];
                int readBytes;
                while ((readBytes = in.read(byteBuffer)) != -1) {
                    out.write(byteBuffer, 0, readBytes);
                }
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}