1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import android.annotation.SuppressLint; 20 import android.content.Context; 21 import android.content.Intent; 22 import android.util.Log; 23 24 import com.android.nn.benchmark.core.BenchmarkException; 25 import com.android.nn.benchmark.core.BenchmarkResult; 26 import com.android.nn.benchmark.core.Processor; 27 import com.android.nn.benchmark.core.TestModels; 28 import com.android.nn.benchmark.core.TfLiteBackend; 29 import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport; 30 import com.android.nn.crashtest.core.CrashTest; 31 import com.android.nn.crashtest.core.CrashTestCoordinator; 32 33 import java.io.IOException; 34 import java.util.Arrays; 35 import java.util.List; 36 import java.util.Optional; 37 import java.util.concurrent.Callable; 38 import java.util.concurrent.CountDownLatch; 39 import java.util.concurrent.ExecutionException; 40 import java.util.concurrent.ExecutorService; 41 import java.util.concurrent.Executors; 42 import java.util.concurrent.Future; 43 import java.util.stream.Stream; 44 45 public class PerformanceDegradationTest implements CrashTest { 46 public static final String TAG = "NN_PERF_DEG"; 47 48 private static final Processor.Callback mNoOpCallback = new Processor.Callback() { 49 @Override 50 public void onBenchmarkFinish(boolean ok) { 51 } 52 53 @Override 54 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 55 } 56 }; 57 58 public static final String WARMUP_SECONDS = "warmup_seconds"; 59 public static final String RUN_TIME_SECONDS = "run_time_seconds"; 60 public static final String ACCELERATOR_NAME = "accelerator_name"; 61 public static final float DEFAULT_WARMUP_SECONDS = 3.0f; 62 public static final float DEFAULT_RUN_TIME_SECONDS = 10.0f; 63 public static final String THREAD_COUNT = "thread_count"; 64 public static final int DEFAULT_THREAD_COUNT = 5; 65 public static final String MAX_PERFORMANCE_DEGRADATION = "max_performance_degradation"; 66 public static final int DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE = 100; 67 public static final String TEST_NAME = "test_name"; 68 private static final long INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS = 500; 69 intentInitializer( float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount, int maxPerformanceDegradationPercent, String testName)70 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 71 float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount, 72 int maxPerformanceDegradationPercent, String testName) { 73 return intent -> { 74 intent.putExtra(WARMUP_SECONDS, warmupTimeSeconds); 75 intent.putExtra(RUN_TIME_SECONDS, runTimeSeconds); 76 intent.putExtra(ACCELERATOR_NAME, acceleratorName); 77 intent.putExtra(THREAD_COUNT, threadCount); 78 intent.putExtra(MAX_PERFORMANCE_DEGRADATION, maxPerformanceDegradationPercent); 79 intent.putExtra(TEST_NAME, testName); 80 }; 81 } 82 intentInitializer( Intent copyFrom)83 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 84 Intent copyFrom) { 85 return intentInitializer( 86 copyFrom.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS), 87 copyFrom.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS), 88 copyFrom.getStringExtra(ACCELERATOR_NAME), 89 copyFrom.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT), 90 copyFrom.getIntExtra(MAX_PERFORMANCE_DEGRADATION, 91 DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE), 92 copyFrom.getStringExtra(TEST_NAME)); 93 } 94 95 private Context mContext; 96 private float mWarmupTimeSeconds; 97 private float mRunTimeSeconds; 98 private String mAcceleratorName; 99 private int mThreadCount; 100 private int mMaxPerformanceDegradationPercent; 101 private String mTestName; 102 103 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)104 public void init(Context context, Intent configParams, 105 Optional<ProgressListener> progressListener) { 106 mContext = context; 107 108 mWarmupTimeSeconds = configParams.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS); 109 mRunTimeSeconds = configParams.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS); 110 mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME); 111 mThreadCount = configParams.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT); 112 mMaxPerformanceDegradationPercent = configParams.getIntExtra(MAX_PERFORMANCE_DEGRADATION, 113 DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE); 114 mTestName = configParams.getStringExtra(TEST_NAME); 115 } 116 117 @SuppressLint("DefaultLocale") 118 @Override call()119 public Optional<String> call() throws Exception { 120 List<TestModels.TestModelEntry> modelsForAccelerator = 121 AcceleratorSpecificTestSupport.findAllTestModelsRunningOnAccelerator(mContext, 122 mAcceleratorName); 123 124 if (modelsForAccelerator.isEmpty()) { 125 return failure("Cannot find any model to use for testing"); 126 } 127 128 Log.i(TAG, String.format("Checking performance degradation using %d models", 129 modelsForAccelerator.size())); 130 131 TestModels.TestModelEntry modelForInference = modelsForAccelerator.get(0); 132 // The performance degradation is strongly dependent on the model used to compile 133 // so we check all the available ones. 134 for (TestModels.TestModelEntry modelForCompilation : modelsForAccelerator) { 135 Optional<String> currTestResult = testDegradationForModels(modelForInference, 136 modelForCompilation); 137 if (isFailure(currTestResult)) { 138 return currTestResult; 139 } 140 } 141 142 return success(); 143 } 144 145 @SuppressLint("DefaultLocale") testDegradationForModels( TestModels.TestModelEntry inferenceModelEntry, TestModels.TestModelEntry compilationModelEntry)146 public Optional<String> testDegradationForModels( 147 TestModels.TestModelEntry inferenceModelEntry, 148 TestModels.TestModelEntry compilationModelEntry) throws Exception { 149 Log.i(TAG, String.format( 150 "Testing degradation in inference of model %s when running %d threads compliing " 151 + "model %s", 152 inferenceModelEntry.mModelName, mThreadCount, compilationModelEntry.mModelName)); 153 154 Log.d(TAG, String.format("%s: Calculating baseline", mTestName)); 155 // first let's measure a baseline performance 156 final BenchmarkResult baseline = modelPerformanceCollector(inferenceModelEntry, 157 /*start=*/ null).call(); 158 if (baseline.hasBenchmarkError()) { 159 return failure(String.format("%s: Baseline has benchmark error '%s'", 160 mTestName, baseline.getBenchmarkError())); 161 } 162 Log.d(TAG, String.format("%s: Baseline mean time is %f seconds", mTestName, 163 baseline.getMeanTimeSec())); 164 165 Log.d(TAG, String.format("%s: Sleeping for %d millis", mTestName, 166 INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS)); 167 Thread.sleep(INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS); 168 169 Log.d(TAG, String.format("%s: Calculating performance with %d threads", mTestName, 170 mThreadCount)); 171 final int totalThreadCount = mThreadCount + 1; 172 final CountDownLatch start = new CountDownLatch(totalThreadCount); 173 ModelCompiler[] compilers = Stream.generate( 174 () -> new ModelCompiler(start, mContext, mAcceleratorName, 175 compilationModelEntry)).limit( 176 mThreadCount).toArray( 177 ModelCompiler[]::new); 178 179 Callable<BenchmarkResult> performanceWithOtherCompilingThreadCollector = 180 modelPerformanceCollector(inferenceModelEntry, start); 181 182 ExecutorService testExecutor = Executors.newFixedThreadPool(totalThreadCount); 183 Future<?>[] compilerFutures = Arrays.stream(compilers).map(testExecutor::submit).toArray( 184 Future[]::new); 185 BenchmarkResult benchmarkWithOtherCompilingThread = testExecutor.submit( 186 performanceWithOtherCompilingThreadCollector).get(); 187 188 Arrays.stream(compilers).forEach(ModelCompiler::stop); 189 Arrays.stream(compilerFutures).forEach(future -> { 190 try { 191 future.get(); 192 } catch (InterruptedException | ExecutionException e) { 193 Log.e(TAG, "Error waiting for compiler process completion", e); 194 } 195 }); 196 197 if (benchmarkWithOtherCompilingThread.hasBenchmarkError()) { 198 return failure( 199 String.format( 200 "%s: Test with parallel compiling thrads has benchmark error '%s'", 201 mTestName, benchmarkWithOtherCompilingThread.getBenchmarkError())); 202 } 203 204 Log.d(TAG, String.format("%s: Multithreaded mean time is %f seconds", 205 mTestName, benchmarkWithOtherCompilingThread.getMeanTimeSec())); 206 207 int performanceDegradation = (int) (((benchmarkWithOtherCompilingThread.getMeanTimeSec() 208 / baseline.getMeanTimeSec()) - 1.0) * 100); 209 210 Log.i(TAG, String.format( 211 "%s: Performance degradation for accelerator %s, with %d threads is %d%%. " 212 + "Threshold " 213 + "is %d%%", 214 mTestName, mAcceleratorName, mThreadCount, performanceDegradation, 215 mMaxPerformanceDegradationPercent)); 216 217 if (performanceDegradation > mMaxPerformanceDegradationPercent) { 218 return failure(String.format("Performance degradation is %d%%. Max acceptable is %d%%", 219 performanceDegradation, mMaxPerformanceDegradationPercent)); 220 } 221 222 return success(); 223 } 224 225 modelPerformanceCollector( final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start)226 private Callable<BenchmarkResult> modelPerformanceCollector( 227 final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) { 228 return () -> { 229 Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]); 230 benchmarkProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); 231 benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName); 232 if (start != null) { 233 start.countDown(); 234 start.await(); 235 } 236 final BenchmarkResult result = 237 benchmarkProcessor.getInstrumentationResult( 238 inferenceModelEntry, mWarmupTimeSeconds, mRunTimeSeconds); 239 240 return result; 241 }; 242 } 243 244 private static class ModelCompiler implements Callable<Void> { 245 private static final long SLEEP_BETWEEN_COMPILATION_INTERVAL_MS = 20; 246 private final CountDownLatch mStart; 247 private final Processor mProcessor; 248 private final TestModels.TestModelEntry mTestModelEntry; 249 private volatile boolean mRun; 250 251 ModelCompiler(final CountDownLatch start, final Context context, 252 final String acceleratorName, TestModels.TestModelEntry testModelEntry) { 253 mStart = start; 254 mTestModelEntry = testModelEntry; 255 mProcessor = new Processor(context, mNoOpCallback, new int[0]); 256 mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); 257 mProcessor.setNnApiAcceleratorName(acceleratorName); 258 mProcessor.setRunModelCompilationOnly(true); 259 mRun = true; 260 } 261 262 @Override 263 public Void call() throws IOException, BenchmarkException { 264 if (mStart != null) { 265 try { 266 mStart.countDown(); 267 mStart.await(); 268 } catch (InterruptedException e) { 269 Thread.interrupted(); 270 Log.i(TAG, "Interrupted, stopping processing"); 271 return null; 272 } 273 } 274 while (mRun) { 275 mProcessor.getInstrumentationResult(mTestModelEntry, 0, 0); 276 try { 277 Thread.sleep(SLEEP_BETWEEN_COMPILATION_INTERVAL_MS); 278 } catch (InterruptedException e) { 279 Thread.interrupted(); 280 return null; 281 } 282 } 283 return null; 284 } 285 286 public void stop() { 287 mRun = false; 288 } 289 } 290 } 291