/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.app.Activity;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
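
/**
 * Base class for a benchmarked model. Wraps the native "nnbenchmark_jni" library and
 * drives the model lifecycle: setup, inference runs, and teardown.
 *
 * <p>A minimal usage sketch (illustrative only; how the instance and the Activity are
 * obtained is an assumption, not part of this class):
 *
 * <pre>{@code
 * NNTestBase test = ...;  // one of the benchmark model definitions
 * test.useNNApi();
 * try {
 *     if (test.setupModel(activity)) {
 *         test.runBenchmark(60.0f);  // run for up to 60 seconds
 *     }
 * } finally {
 *     test.destroy();
 * }
 * }</pre>
 */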
public class NNTestBase {
    protected static final String TAG = "NN_TESTBASE";

    // Load the native benchmark library ("nnbenchmark_jni") when this class is first used.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    private synchronized native long initModel(
            String modelFileName,
            boolean useNNApi,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName);

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there is no straightforward golden output value
     * for the benchmark. This also skips calculating the basic (golden-output
     * based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
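
    // Illustrative: callers combine these flags with bitwise OR before passing them to
    // runBenchmark(), e.g.
    //     int flags = FLAG_DISCARD_INFERENCE_OUTPUT | FLAG_IGNORE_GOLDEN_OUTPUT;
    // getDefaultFlags() derives such a combination from the model's current state.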

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    protected Activity mActivity;
    protected TextView mText;
    private String mModelName;
    private String mModelFile;
    private long mModelHandle;
    private int[] mInputShape;
    private InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private boolean mUseNNApi = false;
    private boolean mEnableIntermediateTensorsDump = false;
    private int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();

    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets nor inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets and inputOutputDatasets given. Only one " +
                            "is supported at a time.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
    }

    public void useNNApi() {
        useNNApi(true);
    }

    public void useNNApi(boolean value) {
        mUseNNApi = value;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void setNNApiDeviceName(String value) {
        if (!mUseNNApi) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public final boolean setupModel(Activity ipact) {
        mActivity = ipact;
        String modelFileName = copyAssetToFile();
        if (modelFileName == null) {
            // Asset copy failed; there is no model file to initialize.
            return false;
        }
        mModelHandle = initModel(
                modelFileName, mUseNNApi, mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null));
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        resizeInputTensors(mModelHandle, mInputShape);
        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mActivity.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (mInputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : mInputOutputAssets) {
                inOutList.add(ioAsset.readAssets(mActivity.getAssets()));
            }
        }
        if (mInputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : mInputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(mActivity.getAssets(),
                        mActivity.getCacheDir()));
            }
        }

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else if (lastGolden != mHasGoldenOutputs) {
                throw new IllegalArgumentException("Some inputs for " + mModelName +
                        " have outputs while some don't.");
            }
        }
        return inOutList;
    }

    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        return flags;
    }

    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is set to false; " +
                    "enable it before calling dumpAllLayers");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many as possible before timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through the whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int setRepeat,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            expectedResults += iosSeq.size();
        }
        expectedResults *= setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached the timeout or failed to evaluate the whole set for another reason;
            // abort.
            throw new IllegalStateException(
                    "Failed to evaluate complete input set, expected: "
                            + expectedResults +
                            ", received: " + result.second.size());
        }
        return result;
    }
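
    // Illustrative arithmetic (hypothetical values): with 10 sequences of 5 inferences
    // each and setRepeat == 2, runBenchmarkCompleteInputSet() expects 10 * 5 * 2 == 100
    // results and throws IllegalStateException if fewer arrive (e.g. after a timeout).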

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new BenchmarkException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                    timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
    }

    // Copy the model asset to the cache dir so that TFLite can load it directly from a file.
    private String copyAssetToFile() {
        String modelAssetName = mModelFile + ".tflite";
        AssetManager assetManager = mActivity.getAssets();
        String outFileName = mActivity.getCacheDir().getAbsolutePath() + "/" + modelAssetName;
        // try-with-resources closes both streams even if the copy fails part-way through.
        try (InputStream in = assetManager.open(modelAssetName);
                OutputStream out = new FileOutputStream(new File(outFileName))) {
            byte[] buffer = new byte[1024];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            out.flush();
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            return null;
        }
        return outFileName;
    }
}