1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.app; 18 19 import android.content.Context; 20 import android.util.Log; 21 22 import androidx.test.InstrumentationRegistry; 23 24 import com.android.nn.benchmark.core.BenchmarkException; 25 import com.android.nn.benchmark.core.BenchmarkResult; 26 import com.android.nn.benchmark.core.NNTestBase; 27 import com.android.nn.benchmark.core.NnApiDelegationFailure; 28 import com.android.nn.benchmark.core.Processor; 29 import com.android.nn.benchmark.core.TestModels; 30 import com.android.nn.benchmark.core.TfLiteBackend; 31 32 import java.io.IOException; 33 import java.util.ArrayList; 34 import java.util.Arrays; 35 import java.util.List; 36 import java.util.Optional; 37 import java.util.concurrent.Callable; 38 import java.util.concurrent.atomic.AtomicBoolean; 39 import java.util.stream.Collectors; 40 41 public interface AcceleratorSpecificTestSupport { 42 String TAG = "AcceleratorTest"; 43 findTestModelRunningOnAccelerator( Context context, String acceleratorName)44 static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator( 45 Context context, String acceleratorName) throws NnApiDelegationFailure { 46 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 47 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 48 return Optional.of(model); 49 } 50 } 51 return Optional.empty(); 52 } 53 findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)54 static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator( 55 Context context, String acceleratorName) throws NnApiDelegationFailure { 56 List<TestModels.TestModelEntry> result = new ArrayList<>(); 57 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 58 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 59 result.add(model); 60 } 61 } 62 return result; 63 } 64 ramdomInRange(long min, long max)65 default long ramdomInRange(long min, long max) { 66 return min + (long) (Math.random() * (max - min)); 67 } 68 getTestParameter(String key, String defaultValue)69 static String getTestParameter(String key, String defaultValue) { 70 return InstrumentationRegistry.getArguments().getString(key, defaultValue); 71 } 72 getBooleanTestParameter(String key, boolean defaultValue)73 static boolean getBooleanTestParameter(String key, boolean defaultValue) { 74 // All instrumentation arguments are passed as String so I have to convert the value here. 75 return Boolean.parseBoolean( 76 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue)); 77 } 78 79 static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter"; 80 static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY = 81 "nnCrashtestIncludeNnapiReference"; 82 getTargetAcceleratorNames()83 static List<String> getTargetAcceleratorNames() { 84 List<String> accelerators = new ArrayList<>(); 85 String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+"); 86 accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter( 87 name -> name.matches(acceleratorFilter)).collect( 88 Collectors.toList())); 89 if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) { 90 accelerators.add(null); // running tests with no specified target accelerator too 91 } 92 return accelerators; 93 } 94 95 perAcceleratorTestConfig(List<Object[]> testConfig)96 static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) { 97 return testConfig.stream() 98 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map( 99 accelerator -> { 100 Object[] result = 101 Arrays.copyOf(currConfigurationParams, 102 currConfigurationParams.length + 1); 103 result[currConfigurationParams.length] = accelerator; 104 return result; 105 })) 106 .collect(Collectors.toList()); 107 } 108 109 class DriverLivenessChecker implements Callable<Boolean> { 110 final Processor mProcessor; 111 private final AtomicBoolean mRun = new AtomicBoolean(true); 112 private final TestModels.TestModelEntry mTestModelEntry; 113 DriverLivenessChecker(Context context, String acceleratorName, TestModels.TestModelEntry testModelEntry)114 public DriverLivenessChecker(Context context, String acceleratorName, 115 TestModels.TestModelEntry testModelEntry) { 116 mProcessor = new Processor(context, 117 new Processor.Callback() { 118 @Override 119 public void onBenchmarkFinish(boolean ok) { 120 } 121 122 @Override 123 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 124 } 125 }, new int[0]); 126 mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); 127 mProcessor.setCompleteInputSet(false); 128 mProcessor.setNnApiAcceleratorName(acceleratorName); 129 mTestModelEntry = testModelEntry; 130 } 131 stop()132 public void stop() { 133 mRun.set(false); 134 } 135 136 @Override call()137 public Boolean call() throws Exception { 138 while (mRun.get()) { 139 try { 140 BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult( 141 mTestModelEntry, 0, 3); 142 if (modelExecutionResult.hasBenchmarkError()) { 143 Log.e(TAG, String.format("Benchmark failed with message %s", 144 modelExecutionResult.getBenchmarkError())); 145 return false; 146 } 147 } catch (IOException | BenchmarkException e) { 148 Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName)); 149 return false; 150 } 151 } 152 153 return true; 154 } 155 } 156 } 157