/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.app.Activity;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Wraps a single {@code .tflite} model asset and drives it through the
 * {@code nnbenchmark_jni} native library.
 *
 * <p>Typical usage: construct, optionally configure ({@link #useNNApi()},
 * {@link #enableIntermediateTensorsDump()}, {@link #setNNApiDeviceName(String)}), call
 * {@link #setupModel(Activity)}, run one of the {@code runBenchmark*} methods, and finally
 * call {@link #destroy()} to release the native model.
 */
public class NNTestBase {
    protected static final String TAG = "NN_TESTBASE";

    // Load the 'nnbenchmark_jni' native library when this class is first initialized.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    private synchronized native long initModel(
            String modelFileName,
            boolean useNNApi,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName);

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    protected Activity mActivity;
    protected TextView mText;
    private String mModelName;
    private String mModelFile;
    // Handle returned by initModel(); 0 means "no native model is currently loaded".
    private long mModelHandle;
    private int[] mInputShape;
    private InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    // Updated as a side effect of getInputOutputAssets(); false until inputs are first loaded.
    private boolean mHasGoldenOutputs;
    private boolean mUseNNApi = false;
    private boolean mEnableIntermediateTensorsDump = false;
    private int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();

    /**
     * Creates a test description. No native resources are allocated until
     * {@link #setupModel(Activity)} is called.
     *
     * @param modelName human readable name, returned by {@link #getTestInfo()}
     * @param modelFile asset name of the model without the {@code .tflite} suffix
     * @param inputShape shape passed to the native {@code resizeInputTensors} call
     * @param inputOutputAssets inputs read from assets; exactly one of this and
     *         {@code inputOutputDatasets} must be non-null
     * @param inputOutputDatasets inputs read from datasets; exactly one of this and
     *         {@code inputOutputAssets} must be non-null
     * @param evaluator optional evaluator configuration, may be null
     * @param minSdkVersion minimum SDK level enforced by {@link #checkSdkVersion()};
     *         values {@code <= 0} disable the check
     * @throws IllegalArgumentException if both or neither input sources are given
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets or inputOutputDatasets given. Only one "
                            + "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
    }

    /** Enables NNAPI; equivalent to {@code useNNApi(true)}. */
    public void useNNApi() {
        useNNApi(true);
    }

    /** Sets the NNAPI flag that is passed to the native model at {@link #setupModel} time. */
    public void useNNApi(boolean value) {
        mUseNNApi = value;
    }

    /** Enables intermediate tensor dumping; equivalent to passing {@code true}. */
    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    /**
     * Sets the intermediate-tensor-dump flag that is passed to the native model at
     * {@link #setupModel} time. Must be true for {@link #dumpAllLayers(File, int, int)}.
     */
    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    /**
     * Sets the NNAPI device name passed to the native model at {@link #setupModel} time.
     * The value is stored even when NNAPI is not enabled, but an error is logged in that case.
     *
     * @param value device name, or null to clear it
     */
    public void setNNApiDeviceName(String value) {
        if (!mUseNNApi) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    /**
     * Copies the model asset to the cache directory, initializes the native model and
     * creates the evaluator (if one was configured).
     *
     * <p>If the asset copy fails, no native model is created and this method still returns
     * true; the {@code runBenchmark*} methods will then throw {@link BenchmarkException}.
     *
     * @param ipact activity used to access assets and the cache directory
     * @return false if the native model could not be initialized, true otherwise
     */
    public final boolean setupModel(Activity ipact) {
        mActivity = ipact;
        String modelFileName = copyAssetToFile();
        if (modelFileName != null) {
            mModelHandle = initModel(
                    modelFileName, mUseNNApi, mEnableIntermediateTensorsDump,
                    mNNApiDeviceName.orElse(null));
            if (mModelHandle == 0) {
                Log.e(TAG, "Failed to init the model");
                return false;
            }
            // The result used to be dropped silently; surface it in the log without
            // changing control flow.
            if (!resizeInputTensors(mModelHandle, mInputShape)) {
                Log.w(TAG, "resizeInputTensors returned false for model " + mModelName);
            }
        }
        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mActivity.getAssets());
        }
        return true;
    }

    /** Returns the model name given at construction time. */
    public String getTestInfo() {
        return mModelName;
    }

    /** Returns the evaluator created by {@link #setupModel}, or null if there is none. */
    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /**
     * Verifies that the device SDK level satisfies the minimum given at construction time.
     *
     * @throws UnsupportedSdkException if a positive minimum was given and the device is older
     */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: "
                    + mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    /**
     * Reads all configured inputs (assets or datasets) and records in
     * {@code mHasGoldenOutputs} whether they carry golden outputs.
     *
     * @return the loaded input/output sequences
     * @throws IOException if reading an input fails
     * @throws IllegalArgumentException if only some sequences have golden outputs
     */
    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (mInputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : mInputOutputAssets) {
                inOutList.add(ioAsset.readAssets(mActivity.getAssets()));
            }
        }
        if (mInputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : mInputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(mActivity.getAssets(),
                        mActivity.getCacheDir()));
            }
        }

        // All sequences must agree on whether golden outputs are present.
        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = Boolean.valueOf(mHasGoldenOutputs);
            } else if (lastGolden.booleanValue() != mHasGoldenOutputs) {
                throw new IllegalArgumentException("Some inputs for " + mModelName
                        + " have outputs while some don't.");
            }
        }
        return inOutList;
    }

    /**
     * Computes the benchmark flags implied by the current state: golden outputs are ignored
     * when the most recently loaded inputs have none (or no inputs were loaded yet), and
     * inference output is discarded when there is no evaluator.
     *
     * @return a bitwise OR of {@link #FLAG_IGNORE_GOLDEN_OUTPUT} and
     *         {@link #FLAG_DISCARD_INFERENCE_OUTPUT}, possibly 0
     */
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags |= FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags |= FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        return flags;
    }

    /**
     * Dumps all layers of the model for a sub-range of the inputs into {@code dumpDir}.
     *
     * @param dumpDir existing directory that receives the dump
     * @param inputAssetIndex index of the first input sequence to use (inclusive)
     * @param inputAssetSize passed as the exclusive end index of the sub-range
     * @throws IllegalArgumentException if {@code dumpDir} is missing or not a directory
     * @throws IllegalStateException if intermediate tensor dumping was not enabled
     * @throws IOException if reading the inputs fails
     */
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is "
                    + "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        // NOTE(review): despite its name, inputAssetSize is used as the END index of the
        // range, not as a count - confirm against callers before changing.
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    /**
     * Runs the benchmark with a maximum sequence count of 1 and no practical timeout.
     *
     * @return the inputs used and the results produced
     * @throws IOException if reading the inputs fails
     * @throws BenchmarkException if no model is loaded or the native run fails
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    /**
     * Runs as many inferences as possible before the timeout.
     *
     * @param timeoutSec timeout handed to the native benchmark
     * @return the inputs used and the results produced
     * @throws IOException if reading the inputs fails
     * @throws BenchmarkException if no model is loaded or the native run fails
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Inputs must be loaded before computing the flags: loading is what records whether
        // golden outputs exist, and getDefaultFlags() reads that state.
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        // Run as many as possible before timeout.
        return runBenchmark(ios, 0xFFFFFFF, timeoutSec, flags);
    }

    /**
     * Run through whole input set (once or multiple times).
     *
     * @param setRepeat how many times the whole input set should be processed
     * @param timeoutSec timeout handed to the native benchmark
     * @return the inputs used and the results produced
     * @throws IOException if reading the inputs fails
     * @throws BenchmarkException if no model is loaded or the native run fails
     * @throws IllegalStateException if fewer or more results than expected were produced
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int setRepeat,
            float timeoutSec)
            throws IOException, BenchmarkException {
        // Load inputs first so that the flags reflect whether golden outputs are present.
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            expectedResults += iosSeq.size();
        }
        expectedResults *= setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec, flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            throw new IllegalStateException(
                    "Failed to evaluate complete input set, expected: "
                            + expectedResults
                            + ", received: " + result.second.size());
        }
        return result;
    }

    /**
     * Runs the native benchmark on the given inputs.
     *
     * @param inOutList inputs to run
     * @param inferencesSeqMaxCount maximum sequence count handed to the native benchmark
     * @param timeoutSec timeout handed to the native benchmark
     * @param flags bitwise OR of the {@code FLAG_*} constants
     * @return a pair of {@code inOutList} and the list of results filled by the native code
     * @throws BenchmarkException if no model is loaded or the native run reports failure
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new BenchmarkException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>(
                inOutList, resultList);
    }

    /** Releases the native model, if one is loaded. Safe to call more than once. */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
    }

    /**
     * Copies the model asset into the cache directory, so that TFLite can load it directly.
     *
     * @return absolute path of the copied file, or null if the copy failed
     */
    private String copyAssetToFile() {
        String modelAssetName = mModelFile + ".tflite";
        AssetManager assetManager = mActivity.getAssets();
        String outFileName = mActivity.getCacheDir().getAbsolutePath() + "/" + modelAssetName;

        // try-with-resources closes both streams even when read/write throws.
        try (InputStream in = assetManager.open(modelAssetName);
                OutputStream out = new FileOutputStream(new File(outFileName))) {
            byte[] buffer = new byte[1024];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            out.flush();
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            return null;
        }
        return outFileName;
    }
}