/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.crashtest;

import android.annotation.SuppressLint;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.AdapterView;
import android.widget.ArrayAdapter;
import android.widget.Button;
import android.widget.CheckBox;
import android.widget.NumberPicker;
import android.widget.Spinner;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.NnApiDelegationFailure;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.benchmark.core.TestModelsListLoader;
import com.android.nn.crashtest.core.CrashTestCoordinator;
import com.android.nn.crashtest.core.test.RunModelsInParallel;
import com.android.nn.benchmark.util.TestExternalStorageActivity;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

/**
 * Activity used to configure and launch a {@link RunModelsInParallel} test through
 * {@link CrashTestCoordinator}.
 *
 * <p>The layout exposes a model spinner, an accelerator spinner, a "separate process / in
 * process" spinner, number pickers for the thread count and the test duration, and check boxes
 * for memory-mapping the model and for running compilation only.
 */
public class MainActivity extends AppCompatActivity {

    /** First entry of the accelerator spinner; selecting it stores a null accelerator name. */
    private static final String ALL_AVAILABLE_ACCELERATORS = "All available";
    private static final String TAG = "NN_STRESS_TEST";
    // Not referenced anywhere in this class.
    private static final int JOB_FREQUENCY_MILLIS = 15 * 60 * 1000; // 15 minutes

    /**
     * Position of the selected model in the model spinner minus one, so that -1 denotes the
     * leading "all models" entry. This is a position in the displayed (possibly filtered) list,
     * not an index into {@link #mAllTestModels}.
     */
    private final AtomicInteger mSelectedModelIndex = new AtomicInteger(-1);
    private final AtomicBoolean mUseSeparateProcess = new AtomicBoolean(true);
    /** Selected accelerator name, or null when "All available" is selected. */
    private final AtomicReference<String> mAcceleratorName = new AtomicReference<>(null);
    /** Set by {@link #startTestClicked} and cleared by {@link #testStopped}. */
    AtomicBoolean mTestRunning = new AtomicBoolean(false);
    private Button mStartTestButton;
    private TextView mMessage;
    private NumberPicker mThreadCount;
    private NumberPicker mTestDurationMinutes;
    /** Backs the model spinner: one header entry followed by the selectable model names. */
    private ArrayAdapter<String> mModelsAdapter;
    /** Test names of every entry of {@code TestModels.modelsList()}, in list order. */
    private List<String> mAllTestModels;
    private CheckBox mMmapModel;
    private CheckBox mCompileModelsOnly;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        TestExternalStorageActivity.testWriteExternalStorage(this, true);

        mStartTestButton = (Button) findViewById(R.id.start_button);
        mMessage = (TextView) findViewById(R.id.message);

        try {
            TestModelsListLoader.parseFromAssets(getAssets());
        } catch (IOException e) {
            // Best effort: the UI is still built from whatever TestModels currently holds.
            Log.e(TAG, "Could not load models", e);
        }

        mAllTestModels = TestModels.modelsList().stream()
                .map(TestModels.TestModelEntry::getTestName)
                .collect(Collectors.toList());

        setUpModelSpinner();
        setUpAcceleratorSpinner();
        setUpTestTypeSpinner();

        mThreadCount = (NumberPicker) findViewById(R.id.thread_count);
        mTestDurationMinutes = (NumberPicker) findViewById(R.id.duration_minutes);

        mThreadCount.setMinValue(1);
        mThreadCount.setMaxValue(20);

        mTestDurationMinutes.setMinValue(1);
        mTestDurationMinutes.setMaxValue(60);

        mMmapModel = (CheckBox) findViewById(R.id.mmap_model);
        mCompileModelsOnly = (CheckBox) findViewById(R.id.compile_only);
    }

    /** Creates {@link #mModelsAdapter} and wires the model spinner to it. */
    private void setUpModelSpinner() {
        // modelsForAccelerator() already starts with the "All models" header; adding a second
        // header here would shift every index derived from the spinner position by one.
        final List<String> modelNames = modelsForAccelerator(ALL_AVAILABLE_ACCELERATORS);
        mModelsAdapter = new ArrayAdapter<String>(this,
                android.R.layout.simple_spinner_item,
                modelNames);
        mModelsAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);

        final Spinner testModelSpinner = (Spinner) findViewById(R.id.test_model);
        testModelSpinner.setAdapter(mModelsAdapter);
        testModelSpinner.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
            @Override
            public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
                // Entry 0 is the header, hence the -1 offset.
                mSelectedModelIndex.set((int) id - 1);
            }

            @Override
            public void onNothingSelected(AdapterView<?> parent) {
                mSelectedModelIndex.set(-1);
            }
        });
    }

    /**
     * Fills the accelerator spinner and refreshes the model list every time the selected
     * accelerator changes.
     */
    private void setUpAcceleratorSpinner() {
        final List<String> acceleratorNames = new ArrayList<>();
        acceleratorNames.add(ALL_AVAILABLE_ACCELERATORS);
        acceleratorNames.addAll(NNTestBase.availableAcceleratorNames());

        final ArrayAdapter<String> acceleratorNamesAdapter = new ArrayAdapter<String>(
                this,
                android.R.layout.simple_spinner_item, acceleratorNames);
        // Same drop-down layout as the other two spinners of this screen.
        acceleratorNamesAdapter.setDropDownViewResource(
                android.R.layout.simple_spinner_dropdown_item);

        final Spinner acceleratorNameSpinner = (Spinner) findViewById(R.id.accelerator_name);
        acceleratorNameSpinner.setAdapter(acceleratorNamesAdapter);
        acceleratorNameSpinner.setOnItemSelectedListener(
                new AdapterView.OnItemSelectedListener() {
                    @Override
                    public void onItemSelected(AdapterView<?> parent, View view, int position,
                            long id) {
                        // Position 0 is the "All available" entry, stored as null.
                        mAcceleratorName.set(
                                position == 0 ? null : acceleratorNames.get(position));
                        mModelsAdapter.clear();
                        mModelsAdapter.addAll(modelsForAccelerator(mAcceleratorName.get()));
                        mModelsAdapter.notifyDataSetChanged();
                    }

                    @Override
                    public void onNothingSelected(AdapterView<?> parent) {
                        onItemSelected(parent, null, 0, 0);
                    }
                });
    }

    /** Wires the spinner choosing between a separate-process and an in-process run. */
    private void setUpTestTypeSpinner() {
        final ArrayAdapter<String> testTypeAdapter = new ArrayAdapter<String>(this,
                android.R.layout.simple_spinner_item,
                new String[]{"Separate process", "In process"});
        testTypeAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);

        final Spinner testTypeSpinner = (Spinner) findViewById(R.id.test_type);
        testTypeSpinner.setAdapter(testTypeAdapter);
        testTypeSpinner.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
            @Override
            public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
                mUseSeparateProcess.set(position == 0);
            }

            @Override
            public void onNothingSelected(AdapterView<?> parent) {
                mUseSeparateProcess.set(true);
            }
        });
    }

    /**
     * Builds the content of the model spinner for the given accelerator.
     *
     * @param acceleratorName accelerator to filter by; null or
     *         {@link #ALL_AVAILABLE_ACCELERATORS} means no filtering
     * @return a new mutable list whose first element is a header entry ("All models" or
     *         "All supported models") followed by the test names of the matching models
     */
    private List<String> modelsForAccelerator(String acceleratorName) {
        List<String> result = new ArrayList<>();
        if (acceleratorName == null || acceleratorName.equals(ALL_AVAILABLE_ACCELERATORS)) {
            result.add("All models");
            result.addAll(mAllTestModels);
            return result;
        }

        result.add("All supported models");
        result.addAll(TestModels.modelsList().stream()
                .map(TestModels.TestModelEntry::withDisabledEvaluation)
                .filter(model -> {
                    try {
                        return Processor.isTestModelSupportedByAccelerator(
                                this, model, acceleratorName);
                    } catch (NnApiDelegationFailure nnApiDelegationFailure) {
                        // A model whose support check fails is reported and left out.
                        reportSupportCheckFailure(
                                String.format(
                                        "Driver %s failed when trying to check support "
                                                + "for model %s!!\n",
                                        acceleratorName, model.mModelName),
                                nnApiDelegationFailure);
                        return false;
                    }
                })
                .map(TestModels.TestModelEntry::getTestName)
                .collect(Collectors.toList()));
        return result;
    }

    /**
     * Appends {@code summary} followed by the stack trace of {@code failure} to the message
     * area and re-enables the start button.
     */
    private void reportSupportCheckFailure(String summary, Throwable failure) {
        final ByteArrayOutputStream stackTraceBytes = new ByteArrayOutputStream();
        try (PrintStream stackTracePrintStream = new PrintStream(stackTraceBytes)) {
            failure.printStackTrace(stackTracePrintStream);
        }
        // Read the byte buffer, not the PrintStream: PrintStream.toString() is only the
        // object's identity string. The stream is closed (hence flushed) at this point.
        final String stackTrace = stackTraceBytes.toString();

        runOnUiThread(() -> {
            mMessage.append(summary);
            mMessage.append(stackTrace + "\n");
            mStartTestButton.setEnabled(true);
        });
    }

    /**
     * Marks the test as no longer running, then shows {@code msg} and re-enables the start
     * button on the UI thread.
     */
    void testStopped(String msg) {
        Log.i(TAG, "Test stopped " + msg);
        mTestRunning.set(false);
        runOnUiThread(() -> {
            mMessage.append(msg + "\n");
            mStartTestButton.setEnabled(true);
        });
    }

    /** Click handler of the start button; does nothing while a test is already running. */
    public void startTestClicked(View v) {
        Log.i(TAG, "Starting test");

        if (mTestRunning.getAndSet(true)) {
            return;
        }

        mStartTestButton.setEnabled(false);

        startInferenceTest();
    }

    /**
     * Resolves the current model selection to indexes into {@link #mAllTestModels}.
     *
     * <p>The spinner only knows positions in the displayed list, which is filtered when an
     * accelerator is selected, so every displayed name is mapped back through
     * {@code mAllTestModels.indexOf}.
     *
     * @return a single index when one model is selected, otherwise one index per displayed
     *         model (the header entry is skipped)
     */
    private int[] selectedTestList() {
        final int displayedCount = mModelsAdapter.getCount();
        // Adapter position of the selected model; position 0 is the header entry.
        final int selectedPosition = mSelectedModelIndex.get() + 1;
        if (selectedPosition >= 1 && selectedPosition < displayedCount) {
            return new int[]{mAllTestModels.indexOf(mModelsAdapter.getItem(selectedPosition))};
        }

        // Header selected, or the stored position no longer exists in the displayed list.
        final int[] testList = new int[Math.max(0, displayedCount - 1)];
        for (int i = 1; i < displayedCount; i++) {
            testList[i - 1] = mAllTestModels.indexOf(mModelsAdapter.getItem(i));
        }
        return testList;
    }

    /**
     * Reads the current UI settings, starts a {@link RunModelsInParallel} test through a new
     * {@link CrashTestCoordinator} and replaces the message area content with a summary line.
     */
    @SuppressLint("DefaultLocale")
    private void startInferenceTest() {
        CrashTestCoordinator coordinator = new CrashTestCoordinator(this);

        int threadCount = mThreadCount.getValue();
        int testDurationMinutes = mTestDurationMinutes.getValue();
        int[] testList = selectedTestList();

        CrashTestCoordinator.CrashTestCompletionListener testCompletionListener =
                new CrashTestCoordinator.CrashTestCompletionListener() {
                    @Override
                    public void testCrashed() {
                        testStopped("Test crashed");
                    }

                    @Override
                    public void testSucceeded() {
                        testStopped("Test succeeded");
                    }

                    @Override
                    public void testFailed(String reason) {
                        testStopped("Test failed with reason " + reason);
                    }

                    @Override
                    public void testProgressing(Optional<String> description) {
                        Log.i(TAG, "> " + description.orElse("Test progressing.."));
                        // Ignoring message to avoid cluttering the test Text Area
                        runOnUiThread(() -> mMessage.append("."));
                    }
                };

        final String testName = "in-app-test@" + System.currentTimeMillis();
        final String acceleratorName = mAcceleratorName.get();
        final boolean mmapModel = mMmapModel.isChecked();
        final boolean runModelCompilationOnly = mCompileModelsOnly.isChecked();
        coordinator.startTest(RunModelsInParallel.class,
                RunModelsInParallel.intentInitializer(testList, threadCount,
                        Duration.ofMinutes(testDurationMinutes),
                        testName, acceleratorName, false, runModelCompilationOnly, mmapModel),
                testCompletionListener,
                mUseSeparateProcess.get(), testName);

        mMessage.setText(
                String.format(
                        "%s test started with %d threads for %d minutes on %s\n",
                        runModelCompilationOnly ? "Compilation" : "Inference",
                        threadCount,
                        testDurationMinutes,
                        acceleratorName != null ? "accelerator " + acceleratorName
                                : "NNAPI-selected accelerator"));
    }
}