1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.content.Context; 22 import android.os.Trace; 23 import android.util.Log; 24 import android.util.Pair; 25 26 import java.io.IOException; 27 import java.util.Collections; 28 import java.util.List; 29 import java.util.concurrent.CountDownLatch; 30 import java.util.concurrent.atomic.AtomicBoolean; 31 32 /** Processor is a helper thread for running the work without blocking the UI thread. */ 33 public class Processor implements Runnable { 34 35 36 public interface Callback { onBenchmarkFinish(boolean ok)37 void onBenchmarkFinish(boolean ok); 38 onStatusUpdate(int testNumber, int numTests, String modelName)39 void onStatusUpdate(int testNumber, int numTests, String modelName); 40 } 41 42 protected static final String TAG = "NN_BENCHMARK"; 43 private Context mContext; 44 45 private final AtomicBoolean mRun = new AtomicBoolean(true); 46 47 volatile boolean mHasBeenStarted = false; 48 // You cannot restart a thread, so the completion flag is final 49 private final CountDownLatch mCompleted = new CountDownLatch(1); 50 private NNTestBase mTest; 51 private int mTestList[]; 52 private BenchmarkResult mTestResults[]; 53 54 private Processor.Callback mCallback; 55 56 private TfLiteBackend mBackend; 57 private boolean mMmapModel; 58 private boolean mCompleteInputSet; 59 private boolean mToggleLong; 60 private boolean mTogglePause; 61 private String mAcceleratorName; 62 private boolean mIgnoreUnsupportedModels; 63 private boolean mRunModelCompilationOnly; 64 // Max number of benchmark iterations to do in run method. 65 // Less or equal to 0 means unlimited 66 private int mMaxRunIterations; 67 68 private boolean mBenchmarkCompilationCaching; 69 private float mCompilationBenchmarkWarmupTimeSeconds; 70 private float mCompilationBenchmarkRunTimeSeconds; 71 private int mCompilationBenchmarkMaxIterations; 72 Processor(Context context, Processor.Callback callback, int[] testList)73 public Processor(Context context, Processor.Callback callback, int[] testList) { 74 mContext = context; 75 mCallback = callback; 76 mTestList = testList; 77 if (mTestList != null) { 78 mTestResults = new BenchmarkResult[mTestList.length]; 79 } 80 mAcceleratorName = null; 81 mIgnoreUnsupportedModels = false; 82 mRunModelCompilationOnly = false; 83 mMaxRunIterations = 0; 84 mBenchmarkCompilationCaching = false; 85 mBackend = TfLiteBackend.CPU; 86 } 87 setUseNNApi(boolean useNNApi)88 public void setUseNNApi(boolean useNNApi) { 89 setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU); 90 } 91 setTfLiteBackend(TfLiteBackend backend)92 public void setTfLiteBackend(TfLiteBackend backend) { 93 mBackend = backend; 94 } 95 setCompleteInputSet(boolean completeInputSet)96 public void setCompleteInputSet(boolean completeInputSet) { 97 mCompleteInputSet = completeInputSet; 98 } 99 setToggleLong(boolean toggleLong)100 public void setToggleLong(boolean toggleLong) { 101 mToggleLong = toggleLong; 102 } 103 setTogglePause(boolean togglePause)104 public void setTogglePause(boolean togglePause) { 105 mTogglePause = togglePause; 106 } 107 setNnApiAcceleratorName(String acceleratorName)108 public void setNnApiAcceleratorName(String acceleratorName) { 109 mAcceleratorName = acceleratorName; 110 } 111 setIgnoreUnsupportedModels(boolean value)112 public void setIgnoreUnsupportedModels(boolean value) { 113 mIgnoreUnsupportedModels = value; 114 } 115 setRunModelCompilationOnly(boolean value)116 public void setRunModelCompilationOnly(boolean value) { 117 mRunModelCompilationOnly = value; 118 } 119 setMmapModel(boolean value)120 public void setMmapModel(boolean value) { 121 mMmapModel = value; 122 } 123 setMaxRunIterations(int value)124 public void setMaxRunIterations(int value) { 125 mMaxRunIterations = value; 126 } 127 enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)128 public void enableCompilationCachingBenchmarks( 129 float warmupTimeSeconds, float runTimeSeconds, int maxIterations) { 130 mBenchmarkCompilationCaching = true; 131 mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds; 132 mCompilationBenchmarkRunTimeSeconds = runTimeSeconds; 133 mCompilationBenchmarkMaxIterations = maxIterations; 134 } 135 getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)136 public BenchmarkResult getInstrumentationResult( 137 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) 138 throws IOException, BenchmarkException { 139 return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false); 140 } 141 142 // Method to retrieve benchmark results for instrumentation tests. 143 // Returns null if the processor is configured to run compilation only getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)144 public BenchmarkResult getInstrumentationResult( 145 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, 146 boolean sampleResults) 147 throws IOException, BenchmarkException { 148 mTest = changeTest(mTest, t); 149 mTest.setSampleResult(sampleResults); 150 try { 151 BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark( 152 warmupTimeSeconds, 153 runTimeSeconds); 154 return result; 155 } finally { 156 mTest.destroy(); 157 mTest = null; 158 } 159 } 160 isTestModelSupportedByAccelerator(Context context, TestModels.TestModelEntry testModelEntry, String acceleratorName)161 public static boolean isTestModelSupportedByAccelerator(Context context, 162 TestModels.TestModelEntry testModelEntry, String acceleratorName) 163 throws NnApiDelegationFailure { 164 try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI, 165 /*enableIntermediateTensorsDump=*/false, 166 /*mmapModel=*/ false)) { 167 tb.setNNApiDeviceName(acceleratorName); 168 return tb.setupModel(context); 169 } catch (IOException e) { 170 Log.w(TAG, 171 String.format("Error trying to check support for model %s on accelerator %s", 172 testModelEntry.mModelName, acceleratorName), e); 173 return false; 174 } catch (NnApiDelegationFailure nnApiDelegationFailure) { 175 if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) { 176 // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device is not 177 // supporting all operation in the model 178 return false; 179 } 180 181 throw nnApiDelegationFailure; 182 } 183 } 184 changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)185 private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) 186 throws IOException, UnsupportedModelException, NnApiDelegationFailure { 187 if (oldTestBase != null) { 188 // Make sure we don't leak memory. 189 oldTestBase.destroy(); 190 } 191 NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false, 192 mMmapModel); 193 if (mBackend == TfLiteBackend.NNAPI) { 194 tb.setNNApiDeviceName(mAcceleratorName); 195 } 196 if (!tb.setupModel(mContext)) { 197 throw new UnsupportedModelException("Cannot initialise model"); 198 } 199 return tb; 200 } 201 202 // Run one loop of kernels for at most the specified minimum time. 203 // The function returns the average time in ms for the test run runBenchmarkLoop(float maxTime, boolean completeInputSet)204 private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet) 205 throws IOException { 206 try { 207 // Run the kernel 208 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results; 209 if (maxTime > 0.f) { 210 if (completeInputSet) { 211 results = mTest.runBenchmarkCompleteInputSet(1, maxTime); 212 } else { 213 results = mTest.runBenchmark(maxTime); 214 } 215 } else { 216 results = mTest.runInferenceOnce(); 217 } 218 return BenchmarkResult.fromInferenceResults( 219 mTest.getTestInfo(), 220 mBackend.toString(), 221 results.first, 222 results.second, 223 mTest.getEvaluator()); 224 } catch (BenchmarkException e) { 225 return new BenchmarkResult(e.getMessage()); 226 } 227 } 228 229 // Run one loop of compilations for at least the specified minimum time. 230 // The function will set the compilation results into the provided benchmark result object. runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, int maxIterations, BenchmarkResult benchmarkResult)231 private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, 232 int maxIterations, BenchmarkResult benchmarkResult) throws IOException { 233 try { 234 CompilationBenchmarkResult result = 235 mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations); 236 benchmarkResult.setCompilationBenchmarkResult(result); 237 } catch (BenchmarkException e) { 238 benchmarkResult.setBenchmarkError(e.getMessage()); 239 } 240 } 241 getTestResults()242 public BenchmarkResult[] getTestResults() { 243 return mTestResults; 244 } 245 246 // Get a benchmark result for a specific test getBenchmark(float warmupTimeSeconds, float runTimeSeconds)247 private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds) 248 throws IOException { 249 try { 250 mTest.checkSdkVersion(); 251 } catch (UnsupportedSdkException e) { 252 BenchmarkResult r = new BenchmarkResult(e.getMessage()); 253 Log.w(TAG, "Unsupported SDK for test: " + r.toString()); 254 return r; 255 } 256 257 // We run a short bit of work before starting the actual test 258 // this is to let any power management do its job and respond. 259 // For NNAPI systrace usage documentation, see 260 // frameworks/ml/nn/common/include/Tracing.h. 261 try { 262 final String traceName = "[NN_LA_PWU]runBenchmarkLoop"; 263 Trace.beginSection(traceName); 264 runBenchmarkLoop(warmupTimeSeconds, false); 265 } finally { 266 Trace.endSection(); 267 } 268 269 // Run the actual benchmark 270 BenchmarkResult r; 271 try { 272 final String traceName = "[NN_LA_PBM]runBenchmarkLoop"; 273 Trace.beginSection(traceName); 274 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet); 275 } finally { 276 Trace.endSection(); 277 } 278 279 // Compilation benchmark 280 if (mBenchmarkCompilationCaching) { 281 runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds, 282 mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r); 283 } 284 285 return r; 286 } 287 288 @Override run()289 public void run() { 290 mHasBeenStarted = true; 291 Log.d(TAG, "Processor starting"); 292 boolean success = true; 293 int benchmarkIterationsCount = 0; 294 try { 295 while (mRun.get()) { 296 if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) { 297 break; 298 } 299 benchmarkIterationsCount++; 300 try { 301 benchmarkAllModels(); 302 } catch (IOException | BenchmarkException e) { 303 Log.e(TAG, "Exception during benchmark run", e); 304 success = false; 305 break; 306 } catch (Throwable e) { 307 Log.e(TAG, "Error during execution", e); 308 throw e; 309 } 310 } 311 Log.d(TAG, "Processor completed work"); 312 mCallback.onBenchmarkFinish(success); 313 } finally { 314 if (mTest != null) { 315 // Make sure we don't leak memory. 316 mTest.destroy(); 317 mTest = null; 318 } 319 mCompleted.countDown(); 320 } 321 } 322 benchmarkAllModels()323 private void benchmarkAllModels() throws IOException, BenchmarkException { 324 // Loop over the tests we want to benchmark 325 for (int ct = 0; ct < mTestList.length; ct++) { 326 if (!mRun.get()) { 327 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); 328 break; 329 } 330 // For reproducibility we wait a short time for any sporadic work 331 // created by the user touching the screen to launch the test to pass. 332 // Also allows for things to settle after the test changes. 333 try { 334 Thread.sleep(250); 335 } catch (InterruptedException ignored) { 336 Thread.currentThread().interrupt(); 337 break; 338 } 339 340 TestModels.TestModelEntry testModel = 341 TestModels.modelsList().get(mTestList[ct]); 342 343 int testNumber = ct + 1; 344 mCallback.onStatusUpdate(testNumber, mTestList.length, 345 testModel.toString()); 346 347 // Select the next test 348 try { 349 mTest = changeTest(mTest, testModel); 350 } catch (UnsupportedModelException e) { 351 if (mIgnoreUnsupportedModels) { 352 Log.d(TAG, String.format( 353 "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct, 354 testModel.mTestName, mAcceleratorName)); 355 } else { 356 Log.e(TAG, 357 String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct, 358 testModel.mTestName, mAcceleratorName), e); 359 throw e; 360 } 361 } 362 363 // If the user selected the "long pause" option, wait 364 if (mTogglePause) { 365 for (int i = 0; (i < 100) && mRun.get(); i++) { 366 try { 367 Thread.sleep(100); 368 } catch (InterruptedException ignored) { 369 Thread.currentThread().interrupt(); 370 break; 371 } 372 } 373 } 374 375 if (mRunModelCompilationOnly) { 376 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName, 377 mBackend.toString(), 378 Collections.emptyList(), 379 Collections.emptyList(), null); 380 } else { 381 // Run the test 382 float warmupTime = 0.3f; 383 float runTime = 1.f; 384 if (mToggleLong) { 385 warmupTime = 2.f; 386 runTime = 10.f; 387 } 388 mTestResults[ct] = getBenchmark(warmupTime, runTime); 389 } 390 } 391 } 392 exit()393 public void exit() { 394 exitWithTimeout(-1l); 395 } 396 exitWithTimeout(long timeoutMs)397 public void exitWithTimeout(long timeoutMs) { 398 mRun.set(false); 399 400 if (mHasBeenStarted) { 401 Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs)); 402 try { 403 if (timeoutMs > 0) { 404 boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); 405 if (!hasCompleted) { 406 Log.w(TAG, "Exiting before execution actually completed"); 407 } 408 } else { 409 mCompleted.await(); 410 } 411 } catch (InterruptedException e) { 412 Thread.currentThread().interrupt(); 413 Log.w(TAG, "Interrupted while waiting for Processor to complete", e); 414 } 415 } 416 417 Log.d(TAG, "Done, cleaning up"); 418 419 if (mTest != null) { 420 mTest.destroy(); 421 mTest = null; 422 } 423 } 424 } 425