/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.content.Context;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Base class for a single NNAPI benchmark test case.
 *
 * <p>Copies a TFLite model out of the APK assets, initializes it through JNI
 * ({@code libnnbenchmark_jni}), feeds it input/output sequences read from assets or
 * datasets, and collects {@link InferenceResult}s. All heavy lifting (model setup,
 * inference, layer dumping) happens in the native methods below.
 */
public class NNTestBase {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator?
    // We only consider a real device, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    private synchronized native long initModel(
            String modelFileName,
            boolean useNNApi,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName);

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    protected Context mContext;
    protected TextView mText;
    private String mModelName;
    private String mModelFile;
    // Opaque native model pointer; 0 means "no model loaded".
    private long mModelHandle;
    private int[] mInputShape;
    private InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    // True when every input sequence ships a golden output to compare against.
    private boolean mHasGoldenOutputs;
    private boolean mUseNNApi = false;
    private boolean mEnableIntermediateTensorsDump = false;
    private int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();

    /**
     * Creates a benchmark test case.
     *
     * @param modelName human-readable model name reported by {@link #getTestInfo()}.
     * @param modelFile asset base name; ".tflite" is appended when copying to cache.
     * @param inputShape shape the model's input tensors are resized to.
     * @param inputOutputAssets inputs/goldens bundled as assets; mutually exclusive
     *        with {@code inputOutputDatasets}.
     * @param inputOutputDatasets inputs/goldens read from a dataset; mutually
     *        exclusive with {@code inputOutputAssets}.
     * @param evaluator optional accuracy-evaluator configuration, may be null.
     * @param minSdkVersion minimum supported SDK; 0 or negative disables the check.
     * @throws IllegalArgumentException if neither or both input sources are given.
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            // Fixed: original concatenation produced "Only onesupported" (missing space).
            throw new IllegalArgumentException(
                    "Both inputOutputAssets or inputOutputDatasets given. Only one " +
                            "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
    }

    /** Enables NNAPI for this test. */
    public void useNNApi() {
        useNNApi(true);
    }

    /** Enables or disables NNAPI for this test. */
    public void useNNApi(boolean value) {
        mUseNNApi = value;
    }

    /** Enables dumping of intermediate tensors (see {@link #dumpAllLayers}). */
    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    /** Enables or disables dumping of intermediate tensors. */
    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    /**
     * Selects the NNAPI device to run on, or null for the default.
     * Only meaningful when NNAPI is enabled via {@link #useNNApi()}.
     */
    public void setNNApiDeviceName(String value) {
        if (!mUseNNApi) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    /**
     * Copies the model to cache, initializes it natively and resizes its inputs.
     *
     * @return true on success, false if the model could not be copied, initialized
     *         or its input tensors resized.
     */
    public final boolean setupModel(Context ipcxt) {
        mContext = ipcxt;
        String modelFileName = copyAssetToFile();
        // Fixed: original returned true even when the asset copy failed, reporting
        // success for a test that never loaded a model.
        if (modelFileName == null) {
            return false;
        }
        mModelHandle = initModel(
                modelFileName, mUseNNApi, mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null));
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        // Fixed: original ignored the boolean result, silently proceeding with a
        // model whose input tensors may not match mInputShape.
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            Log.e(TAG, "Failed to resize input tensors");
            return false;
        }
        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    /** Returns the human-readable model name for reporting. */
    public String getTestInfo() {
        return mModelName;
    }

    /** Returns the evaluator created during {@link #setupModel}, or null. */
    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /**
     * Verifies the device SDK is recent enough for this model.
     *
     * @throws UnsupportedSdkException if the device SDK is below the required minimum.
     */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    /**
     * Reads all configured input/output sequences (assets and/or datasets) and
     * records whether they carry golden outputs.
     *
     * @throws IllegalArgumentException if some sequences have goldens and some don't.
     */
    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (mInputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : mInputOutputAssets) {
                inOutList.add(ioAsset.readAssets(mContext.getAssets()));
            }
        }
        if (mInputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : mInputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(mContext.getAssets(),
                        mContext.getCacheDir()));
            }
        }

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                // Autoboxing instead of the deprecated Boolean(boolean) constructor.
                lastGolden = mHasGoldenOutputs;
            } else if (lastGolden != mHasGoldenOutputs) {
                throw new IllegalArgumentException("Some inputs for " + mModelName +
                        " have outputs while some don't.");
            }
        }
        return inOutList;
    }

    /**
     * Computes default benchmark flags: skip golden-output metrics when no goldens
     * exist, and discard raw inference output when no evaluator will consume it.
     */
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        return flags;
    }

    /**
     * Dumps all intermediate tensors for a sub-range of the input sequences.
     *
     * @param dumpDir existing directory receiving the dump files.
     * @param inputAssetIndex first sequence index (inclusive).
     * @param inputAssetSize end index passed to subList (exclusive) — despite the
     *        name this is an index, not a count; TODO(review) confirm with callers.
     * @throws IllegalArgumentException if {@code dumpDir} is not an existing directory.
     * @throws IllegalStateException if intermediate tensor dumping was not enabled.
     */
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    /** Runs a single pass over the input set with no timeout. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    /** Runs as many inferences as possible before the timeout expires. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many as possible before timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /**
     * Runs through the whole input set (once or multiple times).
     *
     * @throws IllegalStateException if the run produced fewer results than expected
     *         (timeout or other mid-run failure).
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int setRepeat,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            expectedResults += iosSeq.size();
        }
        expectedResults *= setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            final String errorMsg = "Failed to evaluate complete input set, expected: "
                    + expectedResults +
                    ", received: " + result.second.size();
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    /**
     * Core benchmark driver.
     *
     * @param inOutList sequences to feed the model.
     * @param inferencesSeqMaxCount upper bound on sequence inferences to run.
     * @param timeoutSec wall-clock budget for the run.
     * @param flags combination of FLAG_* constants.
     * @return the input sequences paired with the collected results.
     * @throws BenchmarkException if no model is loaded or the native run fails.
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new BenchmarkException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        // Diamond operator instead of spelling out both type arguments.
        return new Pair<>(inOutList, resultList);
    }

    /** Releases the native model. Safe to call multiple times. */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
    }

    /**
     * Copies the model asset into the cache dir so TFLite can load it directly.
     *
     * @return absolute path of the copied model, or null on I/O failure.
     */
    private String copyAssetToFile() {
        String outFileName;
        String modelAssetName = mModelFile + ".tflite";
        AssetManager assetManager = mContext.getAssets();
        try {
            outFileName = mContext.getCacheDir().getAbsolutePath() + "/" + modelAssetName;
            File outFile = new File(outFileName);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {

                byte[] byteBuffer = new byte[1024];
                int readBytes = -1;
                while ((readBytes = in.read(byteBuffer)) != -1) {
                    out.write(byteBuffer, 0, readBytes);
                }
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            return null;
        }
        return outFileName;
    }
}