1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 
21 import android.content.Context;
22 import android.os.Trace;
23 import android.util.Log;
24 import android.util.Pair;
25 
26 import java.io.IOException;
27 import java.util.List;
28 import java.util.concurrent.CountDownLatch;
29 import java.util.concurrent.atomic.AtomicBoolean;
30 
31 /** Processor is a helper thread for running the work without blocking the UI thread. */
32 public class Processor implements Runnable {
33 
34     public interface Callback {
onBenchmarkFinish(boolean ok)35         void onBenchmarkFinish(boolean ok);
36 
onStatusUpdate(int testNumber, int numTests, String modelName)37         void onStatusUpdate(int testNumber, int numTests, String modelName);
38     }
39 
40     protected static final String TAG = "NN_BENCHMARK";
41     private Context mContext;
42 
43     private final AtomicBoolean mRun = new AtomicBoolean(true);
44 
45     volatile boolean mHasBeenStarted = false;
46     // You cannot restart a thread, so the completion flag is final
47     private final CountDownLatch mCompleted = new CountDownLatch(1);
48     private boolean mDoingBenchmark;
49     private NNTestBase mTest;
50     private int mTestList[];
51     private BenchmarkResult mTestResults[];
52 
53     private Processor.Callback mCallback;
54 
55     private boolean mUseNNApi;
56     private boolean mCompleteInputSet;
57     private boolean mToggleLong;
58     private boolean mTogglePause;
59 
Processor(Context context, Processor.Callback callback, int[] testList)60     public Processor(Context context, Processor.Callback callback, int[] testList) {
61         mContext = context;
62         mCallback = callback;
63         mTestList = testList;
64         if (mTestList != null) {
65             mTestResults = new BenchmarkResult[mTestList.length];
66         }
67     }
68 
setUseNNApi(boolean useNNApi)69     public void setUseNNApi(boolean useNNApi) {
70         mUseNNApi = useNNApi;
71     }
72 
setCompleteInputSet(boolean completeInputSet)73     public void setCompleteInputSet(boolean completeInputSet) {
74         mCompleteInputSet = completeInputSet;
75     }
76 
setToggleLong(boolean toggleLong)77     public void setToggleLong(boolean toggleLong) {
78         mToggleLong = toggleLong;
79     }
80 
setTogglePause(boolean togglePause)81     public void setTogglePause(boolean togglePause) {
82         mTogglePause = togglePause;
83     }
84 
85     // Method to retrieve benchmark results for instrumentation tests.
getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)86     public BenchmarkResult getInstrumentationResult(
87             TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
88             throws IOException, BenchmarkException {
89         mTest = changeTest(mTest, t);
90         BenchmarkResult result = getBenchmark(warmupTimeSeconds, runTimeSeconds);
91         mTest.destroy();
92         mTest = null;
93         return result;
94     }
95 
changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)96     private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
97             throws BenchmarkException {
98         if (oldTestBase != null) {
99             // Make sure we don't leak memory.
100             oldTestBase.destroy();
101         }
102         NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */);
103         if (!tb.setupModel(mContext)) {
104             throw new BenchmarkException("Cannot initialise model");
105         }
106         return tb;
107     }
108 
109     // Run one loop of kernels for at least the specified minimum time.
110     // The function returns the average time in ms for the test run
runBenchmarkLoop(float minTime, boolean completeInputSet)111     private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet)
112             throws IOException {
113         try {
114             // Run the kernel
115             Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
116             if (minTime > 0.f) {
117                 if (completeInputSet) {
118                     results = mTest.runBenchmarkCompleteInputSet(1, minTime);
119                 } else {
120                     results = mTest.runBenchmark(minTime);
121                 }
122             } else {
123                 results = mTest.runInferenceOnce();
124             }
125             return BenchmarkResult.fromInferenceResults(
126                     mTest.getTestInfo(),
127                     mUseNNApi
128                             ? BenchmarkResult.BACKEND_TFLITE_NNAPI
129                             : BenchmarkResult.BACKEND_TFLITE_CPU,
130                     results.first,
131                     results.second,
132                     mTest.getEvaluator());
133         } catch (BenchmarkException e) {
134             return new BenchmarkResult(e.getMessage());
135         }
136     }
137 
getTestResults()138     public BenchmarkResult[] getTestResults() {
139         return mTestResults;
140     }
141 
142     // Get a benchmark result for a specific test
getBenchmark(float warmupTimeSeconds, float runTimeSeconds)143     private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
144             throws IOException {
145         try {
146             mTest.checkSdkVersion();
147         } catch (UnsupportedSdkException e) {
148             BenchmarkResult r = new BenchmarkResult(e.getMessage());
149             Log.v(TAG, "Unsupported SDK for test: " + r.toString());
150             return r;
151         }
152 
153         // We run a short bit of work before starting the actual test
154         // this is to let any power management do its job and respond.
155         // For NNAPI systrace usage documentation, see
156         // frameworks/ml/nn/common/include/Tracing.h.
157         try {
158             final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
159             Trace.beginSection(traceName);
160             runBenchmarkLoop(warmupTimeSeconds, false);
161         } finally {
162             Trace.endSection();
163         }
164 
165         // Run the actual benchmark
166         BenchmarkResult r;
167         try {
168             final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
169             Trace.beginSection(traceName);
170             r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
171         } finally {
172             Trace.endSection();
173         }
174 
175         Log.v(TAG, "Completed benchmark loop");
176 
177         return r;
178     }
179 
180     @Override
run()181     public void run() {
182         mHasBeenStarted = true;
183         Log.d(TAG, "Processor starting");
184         try {
185             while (mRun.get()) {
186                 try {
187                     benchmarkAllModels();
188                 } catch (IOException e) {
189                     Log.e(TAG, "IOException during benchmark run", e);
190                     break;
191                 } catch (Throwable e) {
192                     Log.e(TAG, "Error during execution", e);
193                     throw e;
194                 }
195 
196                 mCallback.onBenchmarkFinish(mRun.get());
197             }
198         } finally {
199             mCompleted.countDown();
200         }
201     }
202 
benchmarkAllModels()203     private void benchmarkAllModels() throws IOException {
204         Log.i(TAG, String.format("Iterating through %d models", mTestList.length));
205         // Loop over the tests we want to benchmark
206         for (int ct = 0; ct < mTestList.length; ct++) {
207             if (!mRun.get()) {
208                 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
209                 break;
210             }
211             // For reproducibility we wait a short time for any sporadic work
212             // created by the user touching the screen to launch the test to pass.
213             // Also allows for things to settle after the test changes.
214             try {
215                 Thread.sleep(250);
216             } catch (InterruptedException ignored) {
217                 Thread.currentThread().interrupt();
218                 break;
219             }
220 
221             TestModels.TestModelEntry testModel =
222                     TestModels.modelsList().get(mTestList[ct]);
223 
224             Log.i(TAG, String.format("%d/%d: '%s'", ct, mTestList.length,
225                     testModel.mTestName));
226             int testNumber = ct + 1;
227             mCallback.onStatusUpdate(testNumber, mTestList.length,
228                     testModel.toString());
229 
230             // Select the next test
231             try {
232                 mTest = changeTest(mTest, testModel);
233             } catch (BenchmarkException e) {
234                 Log.w(TAG, String.format("Cannot initialise test %d: '%s', skipping", ct,
235                         testModel.mTestName), e);
236             }
237 
238             // If the user selected the "long pause" option, wait
239             if (mTogglePause) {
240                 for (int i = 0; (i < 100) && mRun.get(); i++) {
241                     try {
242                         Thread.sleep(100);
243                     } catch (InterruptedException ignored) {
244                         Thread.currentThread().interrupt();
245                         break;
246                     }
247                 }
248             }
249 
250             // Run the test
251             float warmupTime = 0.3f;
252             float runTime = 1.f;
253             if (mToggleLong) {
254                 warmupTime = 2.f;
255                 runTime = 10.f;
256             }
257             Log.i(TAG, "Running test for model " + testModel.mModelName + " file "
258                     + testModel.mModelFile);
259             mTestResults[ct] = getBenchmark(warmupTime, runTime);
260         }
261     }
262 
exit()263     public void exit() {
264         exitWithTimeout(-1l);
265     }
266 
exitWithTimeout(long timeoutMs)267     public void exitWithTimeout(long timeoutMs) {
268         mRun.set(false);
269 
270         if (mHasBeenStarted) {
271             try {
272                 if (timeoutMs > 0) {
273                     boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
274                     if (!hasCompleted) {
275                         Log.w(TAG, "Exiting before execution actually completed");
276                     }
277                 } else {
278                     mCompleted.await();
279                 }
280             } catch (InterruptedException e) {
281                 Thread.currentThread().interrupt();
282                 Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
283             }
284         }
285 
286         if (mTest != null) {
287             mTest.destroy();
288             mTest = null;
289         }
290     }
291 }
292