1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.content.Context;
20 import android.content.res.AssetManager;
21 import android.os.Build;
22 import android.util.Log;
23 import android.util.Pair;
24 import android.widget.TextView;
25 
26 import java.io.File;
27 import java.io.FileOutputStream;
28 import java.io.IOException;
29 import java.io.InputStream;
30 import java.io.OutputStream;
31 import java.util.ArrayList;
32 import java.util.List;
33 import java.util.Optional;
34 
35 public class NNTestBase {
36     protected static final String TAG = "NN_TESTBASE";
37 
38     // Used to load the 'native-lib' library on application startup.
39     static {
40         System.loadLibrary("nnbenchmark_jni");
41     }
42 
43     // Does the device has any NNAPI accelerator?
44     // We only consider a real device, not 'nnapi-reference'.
hasAccelerator()45     public static native boolean hasAccelerator();
46 
initModel( String modelFileName, boolean useNNApi, boolean enableIntermediateTensorsDump, String nnApiDeviceName)47     private synchronized native long initModel(
48             String modelFileName,
49             boolean useNNApi,
50             boolean enableIntermediateTensorsDump,
51             String nnApiDeviceName);
52 
destroyModel(long modelHandle)53     private synchronized native void destroyModel(long modelHandle);
54 
resizeInputTensors(long modelHandle, int[] inputShape)55     private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);
56 
57     /** Discard inference output in inference results. */
58     public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
59     /**
60      * Do not expect golden outputs with inference inputs.
61      *
62      * Useful in cases where there's no straightforward golden output values
63      * for the benchmark. This will also skip calculating basic (golden
64      * output based) error metrics.
65      */
66     public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
67 
runBenchmark(long modelHandle, List<InferenceInOutSequence> inOutList, List<InferenceResult> resultList, int inferencesSeqMaxCount, float timeoutSec, int flags)68     private synchronized native boolean runBenchmark(long modelHandle,
69             List<InferenceInOutSequence> inOutList,
70             List<InferenceResult> resultList,
71             int inferencesSeqMaxCount,
72             float timeoutSec,
73             int flags);
74 
dumpAllLayers( long modelHandle, String dumpPath, List<InferenceInOutSequence> inOutList)75     private synchronized native void dumpAllLayers(
76             long modelHandle,
77             String dumpPath,
78             List<InferenceInOutSequence> inOutList);
79 
80     protected Context mContext;
81     protected TextView mText;
82     private String mModelName;
83     private String mModelFile;
84     private long mModelHandle;
85     private int[] mInputShape;
86     private InferenceInOutSequence.FromAssets[] mInputOutputAssets;
87     private InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
88     private EvaluatorConfig mEvaluatorConfig;
89     private EvaluatorInterface mEvaluator;
90     private boolean mHasGoldenOutputs;
91     private boolean mUseNNApi = false;
92     private boolean mEnableIntermediateTensorsDump = false;
93     private int mMinSdkVersion;
94     private Optional<String> mNNApiDeviceName = Optional.empty();
95 
NNTestBase(String modelName, String modelFile, int[] inputShape, InferenceInOutSequence.FromAssets[] inputOutputAssets, InferenceInOutSequence.FromDataset[] inputOutputDatasets, EvaluatorConfig evaluator, int minSdkVersion)96     public NNTestBase(String modelName, String modelFile, int[] inputShape,
97             InferenceInOutSequence.FromAssets[] inputOutputAssets,
98             InferenceInOutSequence.FromDataset[] inputOutputDatasets,
99             EvaluatorConfig evaluator, int minSdkVersion) {
100         if (inputOutputAssets == null && inputOutputDatasets == null) {
101             throw new IllegalArgumentException(
102                     "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
103         }
104         if (inputOutputAssets != null && inputOutputDatasets != null) {
105             throw new IllegalArgumentException(
106                     "Both inputOutputAssets or inputOutputDatasets given. Only one" +
107                             "supported at once.");
108         }
109         mModelName = modelName;
110         mModelFile = modelFile;
111         mInputShape = inputShape;
112         mInputOutputAssets = inputOutputAssets;
113         mInputOutputDatasets = inputOutputDatasets;
114         mModelHandle = 0;
115         mEvaluatorConfig = evaluator;
116         mMinSdkVersion = minSdkVersion;
117     }
118 
useNNApi()119     public void useNNApi() {
120         useNNApi(true);
121     }
122 
useNNApi(boolean value)123     public void useNNApi(boolean value) {
124         mUseNNApi = value;
125     }
126 
enableIntermediateTensorsDump()127     public void enableIntermediateTensorsDump() {
128         enableIntermediateTensorsDump(true);
129     }
130 
enableIntermediateTensorsDump(boolean value)131     public void enableIntermediateTensorsDump(boolean value) {
132         mEnableIntermediateTensorsDump = value;
133     }
134 
setNNApiDeviceName(String value)135     public void setNNApiDeviceName(String value) {
136         if (!mUseNNApi) {
137             Log.e(TAG, "Setting device name has no effect when not using NNAPI");
138         }
139         mNNApiDeviceName = Optional.ofNullable(value);
140     }
141 
setupModel(Context ipcxt)142     public final boolean setupModel(Context ipcxt) {
143         mContext = ipcxt;
144         String modelFileName = copyAssetToFile();
145         if (modelFileName != null) {
146             mModelHandle = initModel(
147                     modelFileName, mUseNNApi, mEnableIntermediateTensorsDump,
148                     mNNApiDeviceName.orElse(null));
149             if (mModelHandle == 0) {
150                 Log.e(TAG, "Failed to init the model");
151                 return false;
152             }
153             resizeInputTensors(mModelHandle, mInputShape);
154         }
155         if (mEvaluatorConfig != null) {
156             mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
157         }
158         return true;
159     }
160 
getTestInfo()161     public String getTestInfo() {
162         return mModelName;
163     }
164 
getEvaluator()165     public EvaluatorInterface getEvaluator() {
166         return mEvaluator;
167     }
168 
checkSdkVersion()169     public void checkSdkVersion() throws UnsupportedSdkException {
170         if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
171             throw new UnsupportedSdkException("SDK version not supported. Mininum required: " +
172                     mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
173         }
174     }
175 
getInputOutputAssets()176     private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
177         // TODO: Caching, don't read inputs for every inference
178         List<InferenceInOutSequence> inOutList = new ArrayList<>();
179         if (mInputOutputAssets != null) {
180             for (InferenceInOutSequence.FromAssets ioAsset : mInputOutputAssets) {
181                 inOutList.add(ioAsset.readAssets(mContext.getAssets()));
182             }
183         }
184         if (mInputOutputDatasets != null) {
185             for (InferenceInOutSequence.FromDataset dataset : mInputOutputDatasets) {
186                 inOutList.addAll(dataset.readDataset(mContext.getAssets(),
187                         mContext.getCacheDir()));
188             }
189         }
190 
191         Boolean lastGolden = null;
192         for (InferenceInOutSequence sequence : inOutList) {
193             mHasGoldenOutputs = sequence.hasGoldenOutput();
194             if (lastGolden == null) {
195                 lastGolden = new Boolean(mHasGoldenOutputs);
196             } else {
197                 if (lastGolden.booleanValue() != mHasGoldenOutputs) {
198                     throw new IllegalArgumentException("Some inputs for " + mModelName +
199                             " have outputs while some don't.");
200                 }
201             }
202         }
203         return inOutList;
204     }
205 
getDefaultFlags()206     public int getDefaultFlags() {
207         int flags = 0;
208         if (!mHasGoldenOutputs) {
209             flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
210         }
211         if (mEvaluator == null) {
212             flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
213         }
214         return flags;
215     }
216 
dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)217     public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
218             throws IOException {
219         if (!dumpDir.exists() || !dumpDir.isDirectory()) {
220             throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
221         }
222         if (!mEnableIntermediateTensorsDump) {
223             throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
224                     "set to false, impossible to proceed");
225         }
226 
227         List<InferenceInOutSequence> ios = getInputOutputAssets();
228         dumpAllLayers(mModelHandle, dumpDir.toString(),
229                 ios.subList(inputAssetIndex, inputAssetSize));
230     }
231 
runInferenceOnce()232     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
233             throws IOException, BenchmarkException {
234         List<InferenceInOutSequence> ios = getInputOutputAssets();
235         int flags = getDefaultFlags();
236         Pair<List<InferenceInOutSequence>, List<InferenceResult>> output =
237                 runBenchmark(ios, 1, Float.MAX_VALUE, flags);
238         return output;
239     }
240 
runBenchmark(float timeoutSec)241     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
242             throws IOException, BenchmarkException {
243         // Run as many as possible before timeout.
244         int flags = getDefaultFlags();
245         return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
246     }
247 
248     /** Run through whole input set (once or mutliple times). */
runBenchmarkCompleteInputSet( int setRepeat, float timeoutSec)249     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
250             int setRepeat,
251             float timeoutSec)
252             throws IOException, BenchmarkException {
253         int flags = getDefaultFlags();
254         List<InferenceInOutSequence> ios = getInputOutputAssets();
255         int totalSequenceInferencesCount = ios.size() * setRepeat;
256         int extpectedResults = 0;
257         for (InferenceInOutSequence iosSeq : ios) {
258             extpectedResults += iosSeq.size();
259         }
260         extpectedResults *= setRepeat;
261 
262         Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
263                 runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
264                         flags);
265         if (result.second.size() != extpectedResults) {
266             // We reached a timeout or failed to evaluate whole set for other reason, abort.
267             final String errorMsg = "Failed to evaluate complete input set, expected: "
268                     + extpectedResults +
269                     ", received: " + result.second.size();
270             Log.w(TAG, errorMsg);
271             throw new IllegalStateException(errorMsg);
272         }
273         return result;
274     }
275 
runBenchmark( List<InferenceInOutSequence> inOutList, int inferencesSeqMaxCount, float timeoutSec, int flags)276     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
277             List<InferenceInOutSequence> inOutList,
278             int inferencesSeqMaxCount,
279             float timeoutSec,
280             int flags)
281             throws IOException, BenchmarkException {
282         if (mModelHandle == 0) {
283             throw new BenchmarkException("Unsupported model");
284         }
285         List<InferenceResult> resultList = new ArrayList<>();
286         if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
287                 timeoutSec, flags)) {
288             throw new BenchmarkException("Failed to run benchmark");
289         }
290         return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>(
291                 inOutList, resultList);
292     }
293 
destroy()294     public void destroy() {
295         if (mModelHandle != 0) {
296             destroyModel(mModelHandle);
297             mModelHandle = 0;
298         }
299     }
300 
301     // We need to copy it to cache dir, so that TFlite can load it directly.
copyAssetToFile()302     private String copyAssetToFile() {
303         String outFileName;
304         String modelAssetName = mModelFile + ".tflite";
305         AssetManager assetManager = mContext.getAssets();
306         try {
307             outFileName = mContext.getCacheDir().getAbsolutePath() + "/" + modelAssetName;
308             File outFile = new File(outFileName);
309 
310             try (InputStream in = assetManager.open(modelAssetName);
311                  FileOutputStream out = new FileOutputStream(outFile)) {
312 
313                 byte[] byteBuffer = new byte[1024];
314                 int readBytes = -1;
315                 while ((readBytes = in.read(byteBuffer)) != -1) {
316                     out.write(byteBuffer, 0, readBytes);
317                 }
318             }
319         } catch (IOException e) {
320             Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
321             return null;
322         }
323         return outFileName;
324     }
325 }
326