1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.app;
18 
19 import android.app.Activity;
20 import android.content.Intent;
21 import android.os.Bundle;
22 import android.os.Trace;
23 import android.util.Log;
24 import android.util.Pair;
25 import android.view.WindowManager;
26 import android.widget.TextView;
27 
28 import com.android.nn.benchmark.core.BenchmarkException;
29 import com.android.nn.benchmark.core.BenchmarkResult;
30 import com.android.nn.benchmark.core.InferenceInOutSequence;
31 import com.android.nn.benchmark.core.InferenceResult;
32 import com.android.nn.benchmark.core.NNTestBase;
33 import com.android.nn.benchmark.core.TestModels;
34 import com.android.nn.benchmark.core.UnsupportedSdkException;
35 
36 import java.util.List;
37 import java.io.IOException;
38 
39 public class NNBenchmark extends Activity {
40     protected static final String TAG = "NN_BENCHMARK";
41 
42     public static final String EXTRA_ENABLE_LONG = "enable long";
43     public static final String EXTRA_ENABLE_PAUSE = "enable pause";
44     public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI";
45     public static final String EXTRA_DEMO = "demo";
46     public static final String EXTRA_TESTS = "tests";
47 
48     public static final String EXTRA_RESULTS_TESTS = "tests";
49     public static final String EXTRA_RESULTS_RESULTS = "results";
50 
51     private int mTestList[];
52     private BenchmarkResult mTestResults[];
53 
54     private TextView mTextView;
55     private boolean mToggleLong;
56     private boolean mTogglePause;
57 
58     private boolean mUseNNApi;
59     private boolean mCompleteInputSet;
60 
setUseNNApi(boolean useNNApi)61     protected void setUseNNApi(boolean useNNApi) {
62         mUseNNApi = useNNApi;
63     }
64 
setCompleteInputSet(boolean completeInputSet)65     protected void setCompleteInputSet(boolean completeInputSet) {
66         mCompleteInputSet = completeInputSet;
67     }
68 
69     // Initialize the parameters for Instrumentation tests.
prepareInstrumentationTest()70     protected void prepareInstrumentationTest() {
71         mTestList = new int[1];
72         mTestResults = new BenchmarkResult[1];
73         mProcessor = new Processor();
74     }
75 
76     /////////////////////////////////////////////////////////////////////////
77     // Processor is a helper thread for running the work without
78     // blocking the UI thread.
79     class Processor extends Thread {
80         private float mLastResult;
81         private boolean mRun = true;
82         private boolean mDoingBenchmark;
83         private NNTestBase mTest;
84 
85         // Method to retrieve benchmark results for instrumentation tests.
getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)86         BenchmarkResult getInstrumentationResult(
87                 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
88                 throws IOException {
89             mTest = changeTest(t);
90             return getBenchmark(warmupTimeSeconds, runTimeSeconds);
91         }
92 
93         // Run one loop of kernels for at least the specified minimum time.
94         // The function returns the average time in ms for the test run
runBenchmarkLoop(float minTime, boolean completeInputSet)95         private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet)
96                 throws IOException {
97             try {
98                 // Run the kernel
99                 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
100                 if (minTime > 0.f) {
101                     if (completeInputSet) {
102                         results = mTest.runBenchmarkCompleteInputSet(1, minTime);
103                     } else {
104                         results = mTest.runBenchmark(minTime);
105                     }
106                 } else {
107                     results = mTest.runInferenceOnce();
108                 }
109                 return BenchmarkResult.fromInferenceResults(
110                         mTest.getTestInfo(),
111                         mUseNNApi ? BenchmarkResult.BACKEND_TFLITE_NNAPI
112                                 : BenchmarkResult.BACKEND_TFLITE_CPU,
113                         results.first, results.second, mTest.getEvaluator());
114             } catch (BenchmarkException e) {
115                 return new BenchmarkResult(e.getMessage());
116             }
117         }
118 
119 
120         // Get a benchmark result for a specific test
getBenchmark(float warmupTimeSeconds, float runTimeSeconds)121         private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
122             throws IOException {
123             try {
124                 mTest.checkSdkVersion();
125             } catch (UnsupportedSdkException e) {
126                 return new BenchmarkResult(e.getMessage());
127             }
128 
129             mDoingBenchmark = true;
130 
131             long result = 0;
132 
133             // We run a short bit of work before starting the actual test
134             // this is to let any power management do its job and respond.
135             // For NNAPI systrace usage documentation, see
136             // frameworks/ml/nn/common/include/Tracing.h.
137             try {
138                 final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
139                 Trace.beginSection(traceName);
140                 runBenchmarkLoop(warmupTimeSeconds, false);
141             } finally {
142                 Trace.endSection();
143             }
144 
145             // Run the actual benchmark
146             BenchmarkResult r;
147             try {
148                 final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
149                 Trace.beginSection(traceName);
150                 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
151             } finally {
152                 Trace.endSection();
153             }
154 
155             Log.v(TAG, "Test: " + r.toString());
156 
157             mDoingBenchmark = false;
158             return r;
159         }
160 
161         @Override
run()162         public void run() {
163             while (mRun) {
164                 // Our loop for launching tests or benchmarks
165                 synchronized (this) {
166                     // We may have been asked to exit while waiting
167                     if (!mRun) return;
168                 }
169 
170                 try {
171                     // Loop over the tests we want to benchmark
172                     for (int ct = 0; (ct < mTestList.length) && mRun; ct++) {
173 
174                         // For reproducibility we wait a short time for any sporadic work
175                         // created by the user touching the screen to launch the test to pass.
176                         // Also allows for things to settle after the test changes.
177                         try {
178                             sleep(250);
179                         } catch (InterruptedException e) {
180                         }
181 
182                         // If we just ran a test, we destroy it here to relieve some memory
183                         // pressure
184 
185                         if (mTest != null) {
186                             mTest.destroy();
187                         }
188 
189                         TestModels.TestModelEntry testModel =
190                             TestModels.modelsList().get(mTestList[ct]);
191                         int testNumber = ct + 1;
192                         runOnUiThread(() -> {
193                             mTextView.setText(
194                                 String.format(
195                                     "Running test %d of %d: %s",
196                                     testNumber,
197                                     mTestList.length,
198                                     testModel.toString()));
199                         });
200 
201                         // Select the next test
202                         mTest = changeTest(testModel);
203 
204                         // If the user selected the "long pause" option, wait
205                         if (mTogglePause) {
206                             for (int i = 0; (i < 100) && mRun; i++) {
207                                 try {
208                                     sleep(100);
209                                 } catch (InterruptedException e) {
210                                 }
211                             }
212                         }
213 
214                         // Run the test
215                         float warmupTime = 0.3f;
216                         float runTime = 1.f;
217                         if (mToggleLong) {
218                             warmupTime = 2.f;
219                             runTime = 10.f;
220                         }
221                         mTestResults[ct] = getBenchmark(warmupTime, runTime);
222                     }
223                     onBenchmarkFinish(mRun);
224                 } catch (IOException e) {
225                     Log.e(TAG, "Exception during benchmark run", e);
226                     break;
227                 }
228             }
229         }
230 
exit()231         public void exit() {
232             mRun = false;
233 
234             synchronized (this) {
235                 notifyAll();
236             }
237 
238             try {
239                 this.join();
240             } catch (InterruptedException e) {
241             }
242 
243             if (mTest != null) {
244                 mTest.destroy();
245                 mTest = null;
246             }
247         }
248     }
249 
250 
251     private boolean mDoingBenchmark;
252     public Processor mProcessor;
253 
changeTest(TestModels.TestModelEntry t)254     NNTestBase changeTest(TestModels.TestModelEntry t) {
255         NNTestBase tb = t.createNNTestBase(mUseNNApi,
256                 false /* enableIntermediateTensorsDump */);
257         tb.setupModel(this);
258         return tb;
259     }
260 
261     @Override
onCreate(Bundle savedInstanceState)262     protected void onCreate(Bundle savedInstanceState) {
263         super.onCreate(savedInstanceState);
264         mTextView = new TextView(this);
265         mTextView.setTextSize(20);
266         mTextView.setText("Running NN benchmark...");
267         setContentView(mTextView);
268         getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
269     }
270 
271     @Override
onPause()272     protected void onPause() {
273         super.onPause();
274         if (mProcessor != null) {
275             mProcessor.exit();
276         }
277     }
278 
onBenchmarkFinish(boolean ok)279     public void onBenchmarkFinish(boolean ok) {
280         if (ok) {
281             Intent intent = new Intent();
282             intent.putExtra(EXTRA_RESULTS_TESTS, mTestList);
283             intent.putExtra(EXTRA_RESULTS_RESULTS, mTestResults);
284             setResult(RESULT_OK, intent);
285         } else {
286             setResult(RESULT_CANCELED);
287         }
288         finish();
289     }
290 
291     @Override
onResume()292     protected void onResume() {
293         super.onResume();
294         Intent i = getIntent();
295         mTestList = i.getIntArrayExtra(EXTRA_TESTS);
296         mToggleLong = i.getBooleanExtra(EXTRA_ENABLE_LONG, false);
297         mTogglePause = i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false);
298         setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
299 
300         if (mTestList != null) {
301             mTestResults = new BenchmarkResult[mTestList.length];
302             mProcessor = new Processor();
303             mProcessor.start();
304         }
305     }
306 
307     @Override
onDestroy()308     protected void onDestroy() {
309         super.onDestroy();
310     }
311 }
312