1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 
21 import android.content.Context;
22 import android.os.Trace;
23 import android.util.Log;
24 import android.util.Pair;
25 
26 import java.io.IOException;
27 import java.util.Collections;
28 import java.util.List;
29 import java.util.concurrent.CountDownLatch;
30 import java.util.concurrent.atomic.AtomicBoolean;
31 
32 /** Processor is a helper thread for running the work without blocking the UI thread. */
33 public class Processor implements Runnable {
34 
35 
36     public interface Callback {
onBenchmarkFinish(boolean ok)37         void onBenchmarkFinish(boolean ok);
38 
onStatusUpdate(int testNumber, int numTests, String modelName)39         void onStatusUpdate(int testNumber, int numTests, String modelName);
40     }
41 
42     protected static final String TAG = "NN_BENCHMARK";
43     private Context mContext;
44 
45     private final AtomicBoolean mRun = new AtomicBoolean(true);
46 
47     volatile boolean mHasBeenStarted = false;
48     // You cannot restart a thread, so the completion flag is final
49     private final CountDownLatch mCompleted = new CountDownLatch(1);
50     private NNTestBase mTest;
51     private int mTestList[];
52     private BenchmarkResult mTestResults[];
53 
54     private Processor.Callback mCallback;
55 
56     private TfLiteBackend mBackend;
57     private boolean mMmapModel;
58     private boolean mCompleteInputSet;
59     private boolean mToggleLong;
60     private boolean mTogglePause;
61     private String mAcceleratorName;
62     private boolean mIgnoreUnsupportedModels;
63     private boolean mRunModelCompilationOnly;
64     // Max number of benchmark iterations to do in run method.
65     // Less or equal to 0 means unlimited
66     private int mMaxRunIterations;
67 
68     private boolean mBenchmarkCompilationCaching;
69     private float mCompilationBenchmarkWarmupTimeSeconds;
70     private float mCompilationBenchmarkRunTimeSeconds;
71     private int mCompilationBenchmarkMaxIterations;
72 
Processor(Context context, Processor.Callback callback, int[] testList)73     public Processor(Context context, Processor.Callback callback, int[] testList) {
74         mContext = context;
75         mCallback = callback;
76         mTestList = testList;
77         if (mTestList != null) {
78             mTestResults = new BenchmarkResult[mTestList.length];
79         }
80         mAcceleratorName = null;
81         mIgnoreUnsupportedModels = false;
82         mRunModelCompilationOnly = false;
83         mMaxRunIterations = 0;
84         mBenchmarkCompilationCaching = false;
85         mBackend = TfLiteBackend.CPU;
86     }
87 
setUseNNApi(boolean useNNApi)88     public void setUseNNApi(boolean useNNApi) {
89         setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
90     }
91 
setTfLiteBackend(TfLiteBackend backend)92     public void setTfLiteBackend(TfLiteBackend backend) {
93         mBackend = backend;
94     }
95 
setCompleteInputSet(boolean completeInputSet)96     public void setCompleteInputSet(boolean completeInputSet) {
97         mCompleteInputSet = completeInputSet;
98     }
99 
setToggleLong(boolean toggleLong)100     public void setToggleLong(boolean toggleLong) {
101         mToggleLong = toggleLong;
102     }
103 
setTogglePause(boolean togglePause)104     public void setTogglePause(boolean togglePause) {
105         mTogglePause = togglePause;
106     }
107 
setNnApiAcceleratorName(String acceleratorName)108     public void setNnApiAcceleratorName(String acceleratorName) {
109         mAcceleratorName = acceleratorName;
110     }
111 
setIgnoreUnsupportedModels(boolean value)112     public void setIgnoreUnsupportedModels(boolean value) {
113         mIgnoreUnsupportedModels = value;
114     }
115 
setRunModelCompilationOnly(boolean value)116     public void setRunModelCompilationOnly(boolean value) {
117         mRunModelCompilationOnly = value;
118     }
119 
setMmapModel(boolean value)120     public void setMmapModel(boolean value) {
121         mMmapModel = value;
122     }
123 
setMaxRunIterations(int value)124     public void setMaxRunIterations(int value) {
125         mMaxRunIterations = value;
126     }
127 
enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)128     public void enableCompilationCachingBenchmarks(
129             float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
130         mBenchmarkCompilationCaching = true;
131         mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds;
132         mCompilationBenchmarkRunTimeSeconds = runTimeSeconds;
133         mCompilationBenchmarkMaxIterations = maxIterations;
134     }
135 
getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)136     public BenchmarkResult getInstrumentationResult(
137             TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
138             throws IOException, BenchmarkException {
139         return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false);
140     }
141 
142     // Method to retrieve benchmark results for instrumentation tests.
143     // Returns null if the processor is configured to run compilation only
getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)144     public BenchmarkResult getInstrumentationResult(
145             TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds,
146             boolean sampleResults)
147             throws IOException, BenchmarkException {
148         mTest = changeTest(mTest, t);
149         mTest.setSampleResult(sampleResults);
150         try {
151             BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(
152                     warmupTimeSeconds,
153                     runTimeSeconds);
154             return result;
155         } finally {
156             mTest.destroy();
157             mTest = null;
158         }
159     }
160 
isTestModelSupportedByAccelerator(Context context, TestModels.TestModelEntry testModelEntry, String acceleratorName)161     public static boolean isTestModelSupportedByAccelerator(Context context,
162             TestModels.TestModelEntry testModelEntry, String acceleratorName)
163             throws NnApiDelegationFailure {
164         try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI,
165                 /*enableIntermediateTensorsDump=*/false,
166                 /*mmapModel=*/ false)) {
167             tb.setNNApiDeviceName(acceleratorName);
168             return tb.setupModel(context);
169         } catch (IOException e) {
170             Log.w(TAG,
171                     String.format("Error trying to check support for model %s on accelerator %s",
172                             testModelEntry.mModelName, acceleratorName), e);
173             return false;
174         } catch (NnApiDelegationFailure nnApiDelegationFailure) {
175             if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) {
176                 // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device is not
177                 // supporting all operation in the model
178                 return false;
179             }
180 
181             throw nnApiDelegationFailure;
182         }
183     }
184 
changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)185     private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
186             throws IOException, UnsupportedModelException, NnApiDelegationFailure {
187         if (oldTestBase != null) {
188             // Make sure we don't leak memory.
189             oldTestBase.destroy();
190         }
191         NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false,
192                 mMmapModel);
193         if (mBackend == TfLiteBackend.NNAPI) {
194             tb.setNNApiDeviceName(mAcceleratorName);
195         }
196         if (!tb.setupModel(mContext)) {
197             throw new UnsupportedModelException("Cannot initialise model");
198         }
199         return tb;
200     }
201 
202     // Run one loop of kernels for at most the specified minimum time.
203     // The function returns the average time in ms for the test run
runBenchmarkLoop(float maxTime, boolean completeInputSet)204     private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet)
205             throws IOException {
206         try {
207             // Run the kernel
208             Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
209             if (maxTime > 0.f) {
210                 if (completeInputSet) {
211                     results = mTest.runBenchmarkCompleteInputSet(1, maxTime);
212                 } else {
213                     results = mTest.runBenchmark(maxTime);
214                 }
215             } else {
216                 results = mTest.runInferenceOnce();
217             }
218             return BenchmarkResult.fromInferenceResults(
219                     mTest.getTestInfo(),
220                     mBackend.toString(),
221                     results.first,
222                     results.second,
223                     mTest.getEvaluator());
224         } catch (BenchmarkException e) {
225             return new BenchmarkResult(e.getMessage());
226         }
227     }
228 
229     // Run one loop of compilations for at least the specified minimum time.
230     // The function will set the compilation results into the provided benchmark result object.
runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, int maxIterations, BenchmarkResult benchmarkResult)231     private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime,
232             int maxIterations, BenchmarkResult benchmarkResult) throws IOException {
233         try {
234             CompilationBenchmarkResult result =
235                     mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations);
236             benchmarkResult.setCompilationBenchmarkResult(result);
237         } catch (BenchmarkException e) {
238             benchmarkResult.setBenchmarkError(e.getMessage());
239         }
240     }
241 
getTestResults()242     public BenchmarkResult[] getTestResults() {
243         return mTestResults;
244     }
245 
246     // Get a benchmark result for a specific test
getBenchmark(float warmupTimeSeconds, float runTimeSeconds)247     private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
248             throws IOException {
249         try {
250             mTest.checkSdkVersion();
251         } catch (UnsupportedSdkException e) {
252             BenchmarkResult r = new BenchmarkResult(e.getMessage());
253             Log.w(TAG, "Unsupported SDK for test: " + r.toString());
254             return r;
255         }
256 
257         // We run a short bit of work before starting the actual test
258         // this is to let any power management do its job and respond.
259         // For NNAPI systrace usage documentation, see
260         // frameworks/ml/nn/common/include/Tracing.h.
261         try {
262             final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
263             Trace.beginSection(traceName);
264             runBenchmarkLoop(warmupTimeSeconds, false);
265         } finally {
266             Trace.endSection();
267         }
268 
269         // Run the actual benchmark
270         BenchmarkResult r;
271         try {
272             final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
273             Trace.beginSection(traceName);
274             r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
275         } finally {
276             Trace.endSection();
277         }
278 
279         // Compilation benchmark
280         if (mBenchmarkCompilationCaching) {
281             runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
282                     mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
283         }
284 
285         return r;
286     }
287 
288     @Override
run()289     public void run() {
290         mHasBeenStarted = true;
291         Log.d(TAG, "Processor starting");
292         boolean success = true;
293         int benchmarkIterationsCount = 0;
294         try {
295             while (mRun.get()) {
296                 if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) {
297                     break;
298                 }
299                 benchmarkIterationsCount++;
300                 try {
301                     benchmarkAllModels();
302                 } catch (IOException | BenchmarkException e) {
303                     Log.e(TAG, "Exception during benchmark run", e);
304                     success = false;
305                     break;
306                 } catch (Throwable e) {
307                     Log.e(TAG, "Error during execution", e);
308                     throw e;
309                 }
310             }
311             Log.d(TAG, "Processor completed work");
312             mCallback.onBenchmarkFinish(success);
313         } finally {
314             if (mTest != null) {
315                 // Make sure we don't leak memory.
316                 mTest.destroy();
317                 mTest = null;
318             }
319             mCompleted.countDown();
320         }
321     }
322 
benchmarkAllModels()323     private void benchmarkAllModels() throws IOException, BenchmarkException {
324         // Loop over the tests we want to benchmark
325         for (int ct = 0; ct < mTestList.length; ct++) {
326             if (!mRun.get()) {
327                 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
328                 break;
329             }
330             // For reproducibility we wait a short time for any sporadic work
331             // created by the user touching the screen to launch the test to pass.
332             // Also allows for things to settle after the test changes.
333             try {
334                 Thread.sleep(250);
335             } catch (InterruptedException ignored) {
336                 Thread.currentThread().interrupt();
337                 break;
338             }
339 
340             TestModels.TestModelEntry testModel =
341                     TestModels.modelsList().get(mTestList[ct]);
342 
343             int testNumber = ct + 1;
344             mCallback.onStatusUpdate(testNumber, mTestList.length,
345                     testModel.toString());
346 
347             // Select the next test
348             try {
349                 mTest = changeTest(mTest, testModel);
350             } catch (UnsupportedModelException e) {
351                 if (mIgnoreUnsupportedModels) {
352                     Log.d(TAG, String.format(
353                             "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct,
354                             testModel.mTestName, mAcceleratorName));
355                 } else {
356                     Log.e(TAG,
357                             String.format("Cannot initialise test %d: '%s'  on accelerator %s.", ct,
358                                     testModel.mTestName, mAcceleratorName), e);
359                     throw e;
360                 }
361             }
362 
363             // If the user selected the "long pause" option, wait
364             if (mTogglePause) {
365                 for (int i = 0; (i < 100) && mRun.get(); i++) {
366                     try {
367                         Thread.sleep(100);
368                     } catch (InterruptedException ignored) {
369                         Thread.currentThread().interrupt();
370                         break;
371                     }
372                 }
373             }
374 
375             if (mRunModelCompilationOnly) {
376                 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
377                         mBackend.toString(),
378                         Collections.emptyList(),
379                         Collections.emptyList(), null);
380             } else {
381                 // Run the test
382                 float warmupTime = 0.3f;
383                 float runTime = 1.f;
384                 if (mToggleLong) {
385                     warmupTime = 2.f;
386                     runTime = 10.f;
387                 }
388                 mTestResults[ct] = getBenchmark(warmupTime, runTime);
389             }
390         }
391     }
392 
exit()393     public void exit() {
394         exitWithTimeout(-1l);
395     }
396 
exitWithTimeout(long timeoutMs)397     public void exitWithTimeout(long timeoutMs) {
398         mRun.set(false);
399 
400         if (mHasBeenStarted) {
401             Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs));
402             try {
403                 if (timeoutMs > 0) {
404                     boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
405                     if (!hasCompleted) {
406                         Log.w(TAG, "Exiting before execution actually completed");
407                     }
408                 } else {
409                     mCompleted.await();
410                 }
411             } catch (InterruptedException e) {
412                 Thread.currentThread().interrupt();
413                 Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
414             }
415         }
416 
417         Log.d(TAG, "Done, cleaning up");
418 
419         if (mTest != null) {
420             mTest.destroy();
421             mTest = null;
422         }
423     }
424 }
425