/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;

/**
 * Base class for NNAPI/TFLite benchmark tests. Wraps a native benchmark model
 * (loaded from an asset), runs inference and compilation benchmarks through JNI,
 * and collects the results.
 *
 * <p>Lifecycle: {@link #setupModel(Context)} must be called before any benchmark
 * method; {@link #destroy()} (or {@link #close()}) releases the native model and
 * the temporary model file.
 */
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device has any NNAPI accelerator?
    // We only consider a real device, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the name of the available NNAPI accelerators
     *
     * @return False if any error occurred, true otherwise
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    /**
     * Returns the names of the available NNAPI accelerators, excluding the
     * software 'nnapi-reference' device.
     *
     * @return the (possibly empty) list of accelerator names; empty on error
     */
    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!!");
            // emptyList() is type-safe, unlike the raw-typed Collections.EMPTY_LIST.
            return Collections.emptyList();
        }
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    /** Collect only 1 benchmark result every 10 **/
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    // Opaque native handle; 0 means "no model loaded".
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    /**
     * @param modelName human-readable model name used in test info and error messages
     * @param modelFile asset base name; "&lt;modelFile&gt;.tflite" is copied to the cache dir
     * @param inputShape shape the input tensors are resized to during setup
     * @param inputOutputAssets inference inputs/goldens from assets; exclusive with datasets
     * @param inputOutputDatasets inference inputs/goldens from datasets; exclusive with assets
     * @param evaluator optional evaluator configuration (may be null)
     * @param minSdkVersion minimum supported SDK, or &lt;= 0 for no restriction
     * @throws IllegalArgumentException if neither or both input sources are given
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets or inputOutputDatasets given. Only one "
                            + "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    /** Selects the NNAPI device; only meaningful when the NNAPI backend is in use. */
    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    /**
     * Copies the model asset to the cache dir and initializes the native model.
     *
     * @return false if native initialization or input-tensor resizing failed
     * @throws IOException if copying the model asset fails
     * @throws NnApiDelegationFailure if the NNAPI delegate could not be applied
     */
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        // Re-running setup replaces any model file left over from a previous run.
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /** @throws UnsupportedSdkException if the device SDK is below the model's minimum */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    // Best-effort delete: a leftover temp model only wastes space, so just warn.
    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause device to run out of space.",
                    path));
        }
    }

    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        // All sequences must agree on whether golden outputs are present; a mix
        // would make golden-based error metrics meaningless.
        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else {
                if (lastGolden != mHasGoldenOutputs) {
                    throw new IllegalArgumentException(
                            "Some inputs for " + mModelName + " have outputs while some don't.");
                }
            }
        }
        return inOutList;
    }

    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    /** Computes the benchmark flags implied by the current configuration. */
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests we will collect only a sample of the results
        if (mSampleResults) {
            flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }

    /**
     * Dumps intermediate tensors for the given input range into dumpDir.
     *
     * @throws IllegalArgumentException if dumpDir is not an existing directory
     * @throws IllegalStateException if intermediate tensor dumping was not enabled
     *         before model setup
     */
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        Pair<List<InferenceInOutSequence>, List<InferenceResult>> output =
                runBenchmark(ios, 1, Float.MAX_VALUE, flags);
        return output;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many as possible before timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                            + " %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    /**
     * Runs the native inference benchmark.
     *
     * @throws UnsupportedModelException if the model was never successfully set up
     * @throws BenchmarkException if the native benchmark run fails
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    /**
     * Runs the native compilation benchmark.
     *
     * @throws UnsupportedModelException if the model was never successfully set up
     * @throws BenchmarkException if the native benchmark returns no result
     */
    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec);
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    /** Releases the native model and deletes the temporary model file. Idempotent. */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy it to cache dir, so that TFlite can load it directly.
    private String copyAssetToFile() throws IOException {
        // Thread id + random suffix keep concurrent test runs from clobbering
        // each other's model copies.
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    /**
     * Copies a model asset into targetFile.
     *
     * @return false if the target file could not be created
     * @throws IOException if reading the asset or writing the file fails
     */
    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    /** Streams the named asset into targetPath, logging and rethrowing on failure. */
    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                byte[] byteBuffer = new byte[1024];
                int readBytes = -1;
                while ((readBytes = in.read(byteBuffer)) != -1) {
                    out.write(byteBuffer, 0, readBytes);
                }
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}