1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.content.Context; 22 import android.os.Trace; 23 import android.util.Log; 24 import android.util.Pair; 25 26 import java.io.IOException; 27 import java.util.List; 28 import java.util.concurrent.CountDownLatch; 29 import java.util.concurrent.atomic.AtomicBoolean; 30 31 /** Processor is a helper thread for running the work without blocking the UI thread. */ 32 public class Processor implements Runnable { 33 34 public interface Callback { onBenchmarkFinish(boolean ok)35 void onBenchmarkFinish(boolean ok); 36 onStatusUpdate(int testNumber, int numTests, String modelName)37 void onStatusUpdate(int testNumber, int numTests, String modelName); 38 } 39 40 protected static final String TAG = "NN_BENCHMARK"; 41 private Context mContext; 42 43 private final AtomicBoolean mRun = new AtomicBoolean(true); 44 45 volatile boolean mHasBeenStarted = false; 46 // You cannot restart a thread, so the completion flag is final 47 private final CountDownLatch mCompleted = new CountDownLatch(1); 48 private boolean mDoingBenchmark; 49 private NNTestBase mTest; 50 private int mTestList[]; 51 private BenchmarkResult mTestResults[]; 52 53 private Processor.Callback mCallback; 54 55 private boolean mUseNNApi; 56 private boolean mCompleteInputSet; 57 private boolean mToggleLong; 58 private boolean mTogglePause; 59 Processor(Context context, Processor.Callback callback, int[] testList)60 public Processor(Context context, Processor.Callback callback, int[] testList) { 61 mContext = context; 62 mCallback = callback; 63 mTestList = testList; 64 if (mTestList != null) { 65 mTestResults = new BenchmarkResult[mTestList.length]; 66 } 67 } 68 setUseNNApi(boolean useNNApi)69 public void setUseNNApi(boolean useNNApi) { 70 mUseNNApi = useNNApi; 71 } 72 setCompleteInputSet(boolean completeInputSet)73 public void setCompleteInputSet(boolean completeInputSet) { 74 mCompleteInputSet = completeInputSet; 75 } 76 setToggleLong(boolean toggleLong)77 public void setToggleLong(boolean toggleLong) { 78 mToggleLong = toggleLong; 79 } 80 setTogglePause(boolean togglePause)81 public void setTogglePause(boolean togglePause) { 82 mTogglePause = togglePause; 83 } 84 85 // Method to retrieve benchmark results for instrumentation tests. getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)86 public BenchmarkResult getInstrumentationResult( 87 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) 88 throws IOException, BenchmarkException { 89 mTest = changeTest(mTest, t); 90 BenchmarkResult result = getBenchmark(warmupTimeSeconds, runTimeSeconds); 91 mTest.destroy(); 92 mTest = null; 93 return result; 94 } 95 changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)96 private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) 97 throws BenchmarkException { 98 if (oldTestBase != null) { 99 // Make sure we don't leak memory. 100 oldTestBase.destroy(); 101 } 102 NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */); 103 if (!tb.setupModel(mContext)) { 104 throw new BenchmarkException("Cannot initialise model"); 105 } 106 return tb; 107 } 108 109 // Run one loop of kernels for at least the specified minimum time. 110 // The function returns the average time in ms for the test run runBenchmarkLoop(float minTime, boolean completeInputSet)111 private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet) 112 throws IOException { 113 try { 114 // Run the kernel 115 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results; 116 if (minTime > 0.f) { 117 if (completeInputSet) { 118 results = mTest.runBenchmarkCompleteInputSet(1, minTime); 119 } else { 120 results = mTest.runBenchmark(minTime); 121 } 122 } else { 123 results = mTest.runInferenceOnce(); 124 } 125 return BenchmarkResult.fromInferenceResults( 126 mTest.getTestInfo(), 127 mUseNNApi 128 ? BenchmarkResult.BACKEND_TFLITE_NNAPI 129 : BenchmarkResult.BACKEND_TFLITE_CPU, 130 results.first, 131 results.second, 132 mTest.getEvaluator()); 133 } catch (BenchmarkException e) { 134 return new BenchmarkResult(e.getMessage()); 135 } 136 } 137 getTestResults()138 public BenchmarkResult[] getTestResults() { 139 return mTestResults; 140 } 141 142 // Get a benchmark result for a specific test getBenchmark(float warmupTimeSeconds, float runTimeSeconds)143 private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds) 144 throws IOException { 145 try { 146 mTest.checkSdkVersion(); 147 } catch (UnsupportedSdkException e) { 148 BenchmarkResult r = new BenchmarkResult(e.getMessage()); 149 Log.v(TAG, "Unsupported SDK for test: " + r.toString()); 150 return r; 151 } 152 153 // We run a short bit of work before starting the actual test 154 // this is to let any power management do its job and respond. 155 // For NNAPI systrace usage documentation, see 156 // frameworks/ml/nn/common/include/Tracing.h. 157 try { 158 final String traceName = "[NN_LA_PWU]runBenchmarkLoop"; 159 Trace.beginSection(traceName); 160 runBenchmarkLoop(warmupTimeSeconds, false); 161 } finally { 162 Trace.endSection(); 163 } 164 165 // Run the actual benchmark 166 BenchmarkResult r; 167 try { 168 final String traceName = "[NN_LA_PBM]runBenchmarkLoop"; 169 Trace.beginSection(traceName); 170 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet); 171 } finally { 172 Trace.endSection(); 173 } 174 175 Log.v(TAG, "Completed benchmark loop"); 176 177 return r; 178 } 179 180 @Override run()181 public void run() { 182 mHasBeenStarted = true; 183 Log.d(TAG, "Processor starting"); 184 try { 185 while (mRun.get()) { 186 try { 187 benchmarkAllModels(); 188 } catch (IOException e) { 189 Log.e(TAG, "IOException during benchmark run", e); 190 break; 191 } catch (Throwable e) { 192 Log.e(TAG, "Error during execution", e); 193 throw e; 194 } 195 196 mCallback.onBenchmarkFinish(mRun.get()); 197 } 198 } finally { 199 mCompleted.countDown(); 200 } 201 } 202 benchmarkAllModels()203 private void benchmarkAllModels() throws IOException { 204 Log.i(TAG, String.format("Iterating through %d models", mTestList.length)); 205 // Loop over the tests we want to benchmark 206 for (int ct = 0; ct < mTestList.length; ct++) { 207 if (!mRun.get()) { 208 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); 209 break; 210 } 211 // For reproducibility we wait a short time for any sporadic work 212 // created by the user touching the screen to launch the test to pass. 213 // Also allows for things to settle after the test changes. 214 try { 215 Thread.sleep(250); 216 } catch (InterruptedException ignored) { 217 Thread.currentThread().interrupt(); 218 break; 219 } 220 221 TestModels.TestModelEntry testModel = 222 TestModels.modelsList().get(mTestList[ct]); 223 224 Log.i(TAG, String.format("%d/%d: '%s'", ct, mTestList.length, 225 testModel.mTestName)); 226 int testNumber = ct + 1; 227 mCallback.onStatusUpdate(testNumber, mTestList.length, 228 testModel.toString()); 229 230 // Select the next test 231 try { 232 mTest = changeTest(mTest, testModel); 233 } catch (BenchmarkException e) { 234 Log.w(TAG, String.format("Cannot initialise test %d: '%s', skipping", ct, 235 testModel.mTestName), e); 236 } 237 238 // If the user selected the "long pause" option, wait 239 if (mTogglePause) { 240 for (int i = 0; (i < 100) && mRun.get(); i++) { 241 try { 242 Thread.sleep(100); 243 } catch (InterruptedException ignored) { 244 Thread.currentThread().interrupt(); 245 break; 246 } 247 } 248 } 249 250 // Run the test 251 float warmupTime = 0.3f; 252 float runTime = 1.f; 253 if (mToggleLong) { 254 warmupTime = 2.f; 255 runTime = 10.f; 256 } 257 Log.i(TAG, "Running test for model " + testModel.mModelName + " file " 258 + testModel.mModelFile); 259 mTestResults[ct] = getBenchmark(warmupTime, runTime); 260 } 261 } 262 exit()263 public void exit() { 264 exitWithTimeout(-1l); 265 } 266 exitWithTimeout(long timeoutMs)267 public void exitWithTimeout(long timeoutMs) { 268 mRun.set(false); 269 270 if (mHasBeenStarted) { 271 try { 272 if (timeoutMs > 0) { 273 boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); 274 if (!hasCompleted) { 275 Log.w(TAG, "Exiting before execution actually completed"); 276 } 277 } else { 278 mCompleted.await(); 279 } 280 } catch (InterruptedException e) { 281 Thread.currentThread().interrupt(); 282 Log.w(TAG, "Interrupted while waiting for Processor to complete", e); 283 } 284 } 285 286 if (mTest != null) { 287 mTest.destroy(); 288 mTest = null; 289 } 290 } 291 } 292