1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.app;
18 
19 import android.annotation.SuppressLint;
20 import android.app.Activity;
21 import android.content.Intent;
22 import android.os.Bundle;
23 import android.util.Log;
24 import android.view.WindowManager;
25 import android.widget.TextView;
26 import com.android.nn.benchmark.core.BenchmarkException;
27 import com.android.nn.benchmark.core.BenchmarkResult;
28 import com.android.nn.benchmark.core.Processor;
29 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
30 import com.android.nn.benchmark.core.TfLiteBackend;
31 import java.io.IOException;
32 import java.time.Duration;
33 import java.util.concurrent.ExecutorService;
34 import java.util.concurrent.Executors;
35 
36 public class NNBenchmark extends Activity implements Processor.Callback {
37     public static final String TAG = "NN_BENCHMARK";
38 
39     public static final String EXTRA_ENABLE_LONG = "enable long";
40     public static final String EXTRA_ENABLE_PAUSE = "enable pause";
41     public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI";
42     public static final String EXTRA_TESTS = "tests";
43 
44     public static final String EXTRA_RESULTS_TESTS = "tests";
45     public static final String EXTRA_RESULTS_RESULTS = "results";
46     public static final long PROCESSOR_TERMINATION_TIMEOUT_MS = Duration.ofSeconds(20).toMillis();
47     public static final String EXTRA_MAX_ITERATIONS = "max_iterations";
48 
49     private int mTestList[];
50 
51     private Processor mProcessor;
52     private final ExecutorService executorService = Executors.newSingleThreadExecutor();
53 
54     private TextView mTextView;
55 
56     // Initialize the parameters for Instrumentation tests.
prepareInstrumentationTest()57     protected void prepareInstrumentationTest() {
58         mTestList = new int[1];
59         mProcessor = new Processor(this, this, mTestList);
60     }
61 
setUseNNApi(boolean useNNApi)62     public void setUseNNApi(boolean useNNApi) {
63         mProcessor.setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
64     }
65 
setNnApiAcceleratorName(String acceleratorName)66     public void setNnApiAcceleratorName(String acceleratorName) {
67         mProcessor.setNnApiAcceleratorName(acceleratorName);
68     }
69 
setCompleteInputSet(boolean completeInputSet)70     public void setCompleteInputSet(boolean completeInputSet) {
71         mProcessor.setCompleteInputSet(completeInputSet);
72     }
73 
enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)74     public void enableCompilationCachingBenchmarks(
75             float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
76         mProcessor.enableCompilationCachingBenchmarks(
77                 warmupTimeSeconds, runTimeSeconds, maxIterations);
78     }
79 
80     @SuppressLint("SetTextI18n")
81     @Override
onCreate(Bundle savedInstanceState)82     protected void onCreate(Bundle savedInstanceState) {
83         super.onCreate(savedInstanceState);
84         mTextView = new TextView(this);
85         mTextView.setTextSize(20);
86         mTextView.setText("Running NN benchmark...");
87         setContentView(mTextView);
88         getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
89     }
90 
91     @Override
onPause()92     protected void onPause() {
93         super.onPause();
94         if (mProcessor != null) {
95             mProcessor.exitWithTimeout(PROCESSOR_TERMINATION_TIMEOUT_MS);
96             mProcessor = null;
97         }
98     }
99 
onBenchmarkFinish(boolean ok)100     public void onBenchmarkFinish(boolean ok) {
101         if (ok) {
102             Intent intent = new Intent();
103             intent.putExtra(EXTRA_RESULTS_TESTS, mTestList);
104             intent.putExtra(EXTRA_RESULTS_RESULTS, mProcessor.getTestResults());
105             setResult(RESULT_OK, intent);
106         } else {
107             setResult(RESULT_CANCELED);
108         }
109         finish();
110     }
111 
112     @SuppressLint("DefaultLocale")
onStatusUpdate(int testNumber, int numTests, String modelName)113     public void onStatusUpdate(int testNumber, int numTests, String modelName) {
114         runOnUiThread(
115                 () -> {
116                     mTextView.setText(
117                             String.format(
118                                     "Running test %d of %d: %s", testNumber, numTests, modelName));
119                 });
120     }
121 
122     @Override
onResume()123     protected void onResume() {
124         super.onResume();
125         Intent i = getIntent();
126         mTestList = i.getIntArrayExtra(EXTRA_TESTS);
127         if (mTestList != null && mTestList.length > 0) {
128             Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length));
129             mProcessor = new Processor(this, this, mTestList);
130             mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
131             mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
132             mProcessor.setTfLiteBackend(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false) ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
133             mProcessor.setMaxRunIterations(i.getIntExtra(EXTRA_MAX_ITERATIONS, 0));
134             executorService.submit(mProcessor);
135         } else {
136             Log.v(TAG, "No test to run, doing nothing");
137         }
138     }
139 
140     @Override
onDestroy()141     protected void onDestroy() {
142         super.onDestroy();
143     }
144 
runSynchronously(TestModelEntry testModel, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)145     public BenchmarkResult runSynchronously(TestModelEntry testModel,
146         float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults) throws IOException, BenchmarkException {
147         return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds, sampleResults);
148     }
149 }
150