1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest;
18 
19 import android.annotation.SuppressLint;
20 import android.os.Bundle;
21 import android.util.Log;
22 import android.view.View;
23 import android.widget.AdapterView;
24 import android.widget.ArrayAdapter;
25 import android.widget.Button;
26 import android.widget.CheckBox;
27 import android.widget.NumberPicker;
28 import android.widget.Spinner;
29 import android.widget.TextView;
30 
31 import androidx.appcompat.app.AppCompatActivity;
32 
33 import com.android.nn.benchmark.core.NNTestBase;
34 import com.android.nn.benchmark.core.NnApiDelegationFailure;
35 import com.android.nn.benchmark.core.Processor;
36 import com.android.nn.benchmark.core.TestModels;
37 import com.android.nn.benchmark.core.TestModelsListLoader;
38 import com.android.nn.crashtest.core.CrashTestCoordinator;
39 import com.android.nn.crashtest.core.test.RunModelsInParallel;
40 import com.android.nn.benchmark.util.TestExternalStorageActivity;
41 
42 import java.io.ByteArrayOutputStream;
43 import java.io.IOException;
44 import java.io.PrintStream;
45 import java.time.Duration;
46 import java.util.ArrayList;
47 import java.util.List;
48 import java.util.Optional;
49 import java.util.concurrent.atomic.AtomicBoolean;
50 import java.util.concurrent.atomic.AtomicInteger;
51 import java.util.concurrent.atomic.AtomicReference;
52 import java.util.stream.Collectors;
53 
54 public class MainActivity extends AppCompatActivity {
55 
56     private static final String ALL_AVAILABLE_ACCELERATORS = "All available";
57     private static final String TAG = "NN_STRESS_TEST";
58     private static final int JOB_FREQUENCY_MILLIS = 15 * 60 * 1000; // 15 minutes
59     private final AtomicInteger mSelectedModelIndex = new AtomicInteger(-1);
60     private final AtomicBoolean mUseSeparateProcess = new AtomicBoolean(true);
61     private final AtomicReference<String> mAcceleratorName = new AtomicReference<>(null);
62     AtomicBoolean mTestRunning = new AtomicBoolean(false);
63     private Button mStartTestButton;
64     private TextView mMessage;
65     private NumberPicker mThreadCount;
66     private NumberPicker mTestDurationMinutes;
67     private ArrayAdapter<String> mModelsAdapter;
68     private List<String> mAllTestModels;
69     private CheckBox mMmapModel;
70     private CheckBox mCompileModelsOnly;
71 
72     @Override
onCreate(Bundle savedInstanceState)73     protected void onCreate(Bundle savedInstanceState) {
74         super.onCreate(savedInstanceState);
75         setContentView(R.layout.activity_main);
76         TestExternalStorageActivity.testWriteExternalStorage(this, true);
77 
78         mStartTestButton = (Button) findViewById(R.id.start_button);
79         mMessage = (TextView) findViewById(R.id.message);
80 
81         try {
82             TestModelsListLoader.parseFromAssets(getAssets());
83         } catch (IOException e) {
84             Log.e(TAG, "Could not load models", e);
85         }
86 
87         mAllTestModels = TestModels.modelsList().stream().map(
88                 TestModels.TestModelEntry::getTestName).collect(
89                 Collectors.toList());
90 
91         final List<String> modelNames = new ArrayList<>();
92         modelNames.add("All Models");
93         modelNames.addAll(modelsForAccelerator(ALL_AVAILABLE_ACCELERATORS));
94         mModelsAdapter = new ArrayAdapter<String>(this,
95                 android.R.layout.simple_spinner_item,
96                 modelNames);
97         mModelsAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);
98         final Spinner testModelSpinner = (Spinner) findViewById(R.id.test_model);
99         testModelSpinner.setAdapter(mModelsAdapter);
100         testModelSpinner.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
101             @Override
102             public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
103                 mSelectedModelIndex.set((int) id - 1);
104             }
105 
106             @Override
107             public void onNothingSelected(AdapterView<?> parent) {
108                 mSelectedModelIndex.set(-1);
109             }
110         });
111 
112         final List<String> acceleratorNames = new ArrayList<>();
113         acceleratorNames.add(ALL_AVAILABLE_ACCELERATORS);
114         acceleratorNames.addAll(NNTestBase.availableAcceleratorNames());
115         final ArrayAdapter<String> acceleratorNamesAdapter = new ArrayAdapter<String>(
116                 this,
117                 android.R.layout.simple_spinner_item, acceleratorNames);
118         final Spinner acceleratorNameSpinner = (Spinner) findViewById(R.id.accelerator_name);
119         acceleratorNameSpinner.setAdapter(acceleratorNamesAdapter);
120         acceleratorNameSpinner.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
121             @Override
122             public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
123                 mAcceleratorName.set(position == 0 ? null : acceleratorNames.get(position));
124                 mModelsAdapter.clear();
125                 mModelsAdapter.addAll(modelsForAccelerator(mAcceleratorName.get()));
126                 mModelsAdapter.notifyDataSetChanged();
127             }
128 
129             @Override
130             public void onNothingSelected(AdapterView<?> parent) {
131                 onItemSelected(parent, null, 0, 0);
132             }
133         });
134 
135         final ArrayAdapter<String> testTypeAdapter = new ArrayAdapter<String>(this,
136                 android.R.layout.simple_spinner_item,
137                 new String[]{"Separate process", "In process"});
138         testTypeAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);
139         final Spinner testTypeSpinner = (Spinner) findViewById(R.id.test_type);
140         testTypeSpinner.setAdapter(testTypeAdapter);
141         testTypeSpinner.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
142             @Override
143             public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
144                 mUseSeparateProcess.set(position == 0);
145             }
146 
147             @Override
148             public void onNothingSelected(AdapterView<?> parent) {
149                 mUseSeparateProcess.set(true);
150             }
151         });
152 
153 
154         mThreadCount = (NumberPicker) findViewById(R.id.thread_count);
155         mTestDurationMinutes = (NumberPicker) findViewById(R.id.duration_minutes);
156 
157         mThreadCount.setMinValue(1);
158         mThreadCount.setMaxValue(20);
159 
160         mTestDurationMinutes.setMinValue(1);
161         mTestDurationMinutes.setMaxValue(60);
162 
163         mMmapModel = (CheckBox) findViewById(R.id.mmap_model);
164         mCompileModelsOnly = (CheckBox) findViewById(R.id.compile_only);
165     }
166 
modelsForAccelerator(String acceleratorName)167     private List<String> modelsForAccelerator(String acceleratorName) {
168         List<String> result = new ArrayList<>();
169         if (acceleratorName == null || acceleratorName.equals(ALL_AVAILABLE_ACCELERATORS)) {
170             result.add("All models");
171             result.addAll(mAllTestModels);
172         } else {
173             result.add("All supported models");
174             result.addAll(TestModels.modelsList().stream()
175                     .map(TestModels.TestModelEntry::withDisabledEvaluation).filter(
176                             model -> {
177                                 try {
178                                     return Processor.isTestModelSupportedByAccelerator(
179                                             this,
180                                             model, acceleratorName);
181                                 } catch (NnApiDelegationFailure nnApiDelegationFailure) {
182                                     runOnUiThread(() -> {
183                                         mMessage.append(String.format(
184                                                 "Driver %s failed when trying to check support "
185                                                         + "for model %s!!\n",
186                                                 acceleratorName, model.mModelName));
187                                         ByteArrayOutputStream stsackTraceByteOS =
188                                                 new ByteArrayOutputStream();
189                                         try (PrintStream stackTracePrintStream = new PrintStream(
190                                                 stsackTraceByteOS)) {
191                                             nnApiDelegationFailure.printStackTrace(
192                                                     stackTracePrintStream);
193                                             mMessage.append(
194                                                     stackTracePrintStream.toString() + "\n");
195                                         }
196                                         mStartTestButton.setEnabled(true);
197                                     });
198                                     return false;
199                                 }
200                             }).map(
201                             TestModels.TestModelEntry::getTestName).collect(
202                             Collectors.toList()));
203         }
204 
205         return result;
206     }
207 
testStopped(String msg)208     void testStopped(String msg) {
209         Log.i(TAG, "Test stopped " + msg);
210         mTestRunning.set(false);
211         runOnUiThread(() -> {
212             mMessage.append(msg + "\n");
213             mStartTestButton.setEnabled(true);
214         });
215     }
216 
startTestClicked(View v)217     public void startTestClicked(View v) {
218         Log.i(TAG, "Starting test");
219 
220         if (mTestRunning.getAndSet(true)) {
221             return;
222         }
223 
224         mStartTestButton.setEnabled(false);
225 
226         startInferenceTest();
227     }
228 
229     @SuppressLint("DefaultLocale")
startInferenceTest()230     private void startInferenceTest() {
231         CrashTestCoordinator coordinator = new CrashTestCoordinator(this);
232 
233         int threadCount = mThreadCount.getValue();
234         int testDurationMinutes = mTestDurationMinutes.getValue();
235 
236         int[] testList;
237         if (mSelectedModelIndex.get() < 0) {
238             testList = new int[mModelsAdapter.getCount()];
239             // The first item is the 'all models' or 'all supported models' entry
240             for (int i = 1; i < mModelsAdapter.getCount(); i++) {
241                 String modelName = mModelsAdapter.getItem(i);
242                 testList[i] = mAllTestModels.indexOf(modelName);
243             }
244         } else {
245             testList = new int[]{mSelectedModelIndex.get()};
246         }
247 
248         CrashTestCoordinator.CrashTestCompletionListener testCompletionListener =
249                 new CrashTestCoordinator.CrashTestCompletionListener() {
250                     @Override
251                     public void testCrashed() {
252                         testStopped("Test crashed");
253                     }
254 
255                     @Override
256                     public void testSucceeded() {
257                         testStopped("Test succeeded");
258                     }
259 
260                     @Override
261                     public void testFailed(String reason) {
262                         testStopped("Test failed with reason " + reason);
263                     }
264 
265                     @Override
266                     public void testProgressing(Optional<String> description) {
267                         Log.i(TAG, "> " + description.orElse("Test progressing.."));
268                         // Ignoring message to avoid cluttering the test Text Area
269                         runOnUiThread(() -> mMessage.append("."));
270                     }
271                 };
272 
273         final int testTimeoutMillis = testDurationMinutes * 1500;
274         final String testName = "in-app-test@" + System.currentTimeMillis();
275         final String acceleratorName = mAcceleratorName.get();
276         final boolean mmapModel = mMmapModel.isChecked();
277         final boolean runModelCompilationOnly = mCompileModelsOnly.isChecked();
278         coordinator.startTest(RunModelsInParallel.class,
279                 RunModelsInParallel.intentInitializer(testList, threadCount,
280                         Duration.ofMinutes(testDurationMinutes),
281                         testName, acceleratorName, false, runModelCompilationOnly, mmapModel),
282                 testCompletionListener,
283                 mUseSeparateProcess.get(), testName);
284 
285         mMessage.setText(
286                 String.format(
287                         "%s test started with %d threads for %d minutes on %s\n",
288                         runModelCompilationOnly ? "Compilation" : "Inference",
289                         threadCount,
290                         testDurationMinutes,
291                         acceleratorName != null ? "accelerator " + acceleratorName
292                                 : "NNAPI-selected accelerator"));
293 
294     }
295 
296 }
297