1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import android.annotation.SuppressLint; 20 import android.content.Context; 21 import android.content.Intent; 22 import android.text.TextUtils; 23 import android.util.Log; 24 25 26 import com.android.nn.crashtest.core.CrashTest; 27 import com.android.nn.crashtest.core.CrashTestCoordinator; 28 29 import java.io.File; 30 import java.time.Duration; 31 import java.time.LocalDateTime; 32 import java.util.Optional; 33 34 public class RandomGraphTest implements CrashTest { 35 private static final String TAG = "NN_RAND_MODEL"; 36 37 private static final boolean ENABLE_NNAPI_LOGS = false; 38 getGeneratorOutFilePath(String fileExtension)39 private String getGeneratorOutFilePath(String fileExtension) { 40 return mContext.getExternalFilesDir(null).getAbsolutePath() + "/" 41 + mTestName.hashCode() + "." + fileExtension; 42 } 43 getNnapiLogFilePath()44 private String getNnapiLogFilePath() { 45 if (ENABLE_NNAPI_LOGS) { 46 String logFile = getGeneratorOutFilePath("model.py"); 47 Log.d(TAG, String.format("Writing NNAPI Fuzzer logs to %s", logFile)); 48 return logFile; 49 } else { 50 return ""; 51 } 52 } 53 getFailedModelDumpPath()54 private String getFailedModelDumpPath() { 55 return getGeneratorOutFilePath("log"); 56 } 57 58 static { 59 System.loadLibrary("random_graph_test_jni"); 60 } 61 62 private enum RandomModelExecutionResult { 63 // This is the java translation of the RandomModelExecutionResult c++ enum in 64 // random_graph_test_jni.cpp 65 kSuccess(0, ""), 66 kFailedCompilation(1, "Compilation failed"), 67 kFailedExecution(2, "Execution failed"), 68 kFailedOtherNnApiCall(3, 69 "Failure trying to interact with the driver"), 70 kInvalidModelGenerated(4, "Unable to generate a valid model"), 71 kUnsupportedModelGenerated(5, "Unable to generate a model supported by the driver"); 72 73 74 private final int mValue; 75 private final String mDescription; 76 RandomModelExecutionResult(int value, String description)77 RandomModelExecutionResult(int value, String description) { 78 mValue = value; 79 mDescription = description; 80 } 81 fromNativeResult(int nativeResult)82 public static RandomModelExecutionResult fromNativeResult(int nativeResult) { 83 for (RandomModelExecutionResult currValue : RandomModelExecutionResult.values()) { 84 if (currValue.mValue == nativeResult) { 85 return currValue; 86 } 87 } 88 throw new IllegalArgumentException( 89 String.format("Invalid native result value %d", nativeResult)); 90 } 91 } 92 93 public static final String MAX_TEST_DURATION = "max_test_duration"; 94 public static final String GRAPH_SIZE = "graph_size"; 95 public static final String DIMENSIONS_RANGE = "dimensions_range"; 96 public static final String MODELS_COUNT = "models_count"; 97 public static final String PAUSE_BETWEEN_MODELS_MS = "pause_between_models_ms"; 98 public static final String COMPILATION_ONLY = "compilation_only"; 99 public static final String DEVICE_NAME = "device_name"; 100 public static final String TEST_NAME = "test_name"; 101 102 public static final int DEFAULT_GRAPH_SIZE = 100; 103 public static final int DEFAULT_DIMENSIONS_RANGE = 100; 104 public static final int DEFAULT_MODELS_COUNT = 100; 105 public static final long DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS = 300; 106 public static final boolean DEFAULT_COMPILATION_ONLY = false; 107 public static final long DEFAULT_MAX_TEST_DURATION_MILLIS = Duration.ofMinutes(2).toMillis(); 108 private static final long MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS = 30; 109 intentInitializer(int graphSize, int dimensionsRange, int modelsCount, long pauseBetweenModelsMillis, boolean compilationOnly, String deviceName, long maxTestDurationMillis, String testName)110 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(int graphSize, 111 int dimensionsRange, int modelsCount, long pauseBetweenModelsMillis, 112 boolean compilationOnly, String deviceName, long maxTestDurationMillis, 113 String testName) { 114 return intent -> { 115 intent.putExtra(GRAPH_SIZE, graphSize); 116 intent.putExtra(DIMENSIONS_RANGE, dimensionsRange); 117 intent.putExtra(MODELS_COUNT, modelsCount); 118 intent.putExtra(PAUSE_BETWEEN_MODELS_MS, pauseBetweenModelsMillis); 119 intent.putExtra(COMPILATION_ONLY, compilationOnly); 120 intent.putExtra(DEVICE_NAME, deviceName); 121 intent.putExtra(MAX_TEST_DURATION, maxTestDurationMillis); 122 intent.putExtra(TEST_NAME, testName); 123 }; 124 } 125 intentInitializer( Intent copyFrom)126 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 127 Intent copyFrom) { 128 return intentInitializer( 129 copyFrom.getIntExtra(RandomGraphTest.GRAPH_SIZE, 130 RandomGraphTest.DEFAULT_GRAPH_SIZE), 131 copyFrom.getIntExtra( 132 RandomGraphTest.DIMENSIONS_RANGE, RandomGraphTest.DEFAULT_DIMENSIONS_RANGE), 133 copyFrom.getIntExtra(RandomGraphTest.MODELS_COUNT, 134 RandomGraphTest.DEFAULT_MODELS_COUNT), 135 copyFrom.getLongExtra(RandomGraphTest.PAUSE_BETWEEN_MODELS_MS, 136 RandomGraphTest.DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS), 137 copyFrom.getBooleanExtra( 138 RandomGraphTest.COMPILATION_ONLY, RandomGraphTest.DEFAULT_COMPILATION_ONLY), 139 copyFrom.getStringExtra(RandomGraphTest.DEVICE_NAME), 140 copyFrom.getLongExtra(MAX_TEST_DURATION, 141 DEFAULT_MAX_TEST_DURATION_MILLIS), 142 copyFrom.getStringExtra(RandomGraphTest.TEST_NAME)); 143 } 144 145 private Context mContext; 146 private String mDeviceName; 147 private boolean mCompilationOnly; 148 private int mGraphSize; 149 private int mDimensionsRange; 150 private int mModelsCount; 151 private long mPauseBetweenModelsMillis; 152 private Duration mMaxTestDuration; 153 private String mTestName; 154 createRandomGraphGenerator(String nnApiDeviceName, int numOperations, int dimensionRange, String testName, String nnapiLogPath, String failedModelDumpPath)155 public static native long createRandomGraphGenerator(String nnApiDeviceName, int numOperations, 156 int dimensionRange, 157 String testName, String nnapiLogPath, String failedModelDumpPath); 158 destroyRandomGraphGenerator(long generatorHandle)159 public static native long destroyRandomGraphGenerator(long generatorHandle); 160 runRandomModel(long generatorHandle, boolean compilationOnly, long maxModelSearchTimeSeconds)161 private static native int runRandomModel(long generatorHandle, 162 boolean compilationOnly, long maxModelSearchTimeSeconds); 163 164 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)165 public void init(Context context, Intent configParams, 166 Optional<ProgressListener> progressListener) { 167 mContext = context; 168 mDeviceName = configParams.getStringExtra(DEVICE_NAME); 169 mCompilationOnly = configParams.getBooleanExtra(COMPILATION_ONLY, DEFAULT_COMPILATION_ONLY); 170 mGraphSize = configParams.getIntExtra(GRAPH_SIZE, DEFAULT_GRAPH_SIZE); 171 mDimensionsRange = configParams.getIntExtra(DIMENSIONS_RANGE, DEFAULT_DIMENSIONS_RANGE); 172 mModelsCount = configParams.getIntExtra(MODELS_COUNT, DEFAULT_MODELS_COUNT); 173 mPauseBetweenModelsMillis = 174 configParams.getLongExtra(PAUSE_BETWEEN_MODELS_MS, 175 DEFAULT_PAUSE_BETWEEN_MODELS_MILLIS); 176 mMaxTestDuration = 177 Duration.ofMillis(configParams.getLongExtra(MAX_TEST_DURATION, 178 DEFAULT_MAX_TEST_DURATION_MILLIS)); 179 mTestName = configParams.getStringExtra(TEST_NAME) != null 180 ? configParams.getStringExtra(TEST_NAME) 181 : "no-name"; 182 } 183 184 @SuppressLint("DefaultLocale") 185 @Override call()186 public Optional<String> call() throws Exception { 187 LocalDateTime testStart = LocalDateTime.now(); 188 Log.i(TAG, 189 String.format(String.format( 190 "Starting test '%s', testing %d models of size %d and dimension range %d " 191 + "for a max duration of %s on device %s.", 192 mTestName, mModelsCount, mGraphSize, mDimensionsRange, mMaxTestDuration, 193 mDeviceName != null ? mDeviceName : "no-device"))); 194 195 final long generatorHandle = RandomGraphTest.createRandomGraphGenerator(mDeviceName, 196 mGraphSize, mDimensionsRange, mTestName, getNnapiLogFilePath(), 197 getFailedModelDumpPath()); 198 if (generatorHandle == 0) { 199 Log.e(TAG, "Unable to initialize random graph generator, failing test"); 200 return failure("Unable to initialize random graph generator"); 201 } 202 try { 203 for (int i = 0; i < mModelsCount; i++) { 204 if (Duration.between(testStart, LocalDateTime.now()).plus( 205 Duration.ofSeconds(MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS)).compareTo( 206 mMaxTestDuration) 207 >= 0) { 208 Log.d(TAG, "Max test duration reached, ending test"); 209 break; 210 } 211 212 int nativeExecutionResult = runRandomModel(generatorHandle, 213 mCompilationOnly, MAX_TIME_TO_LOOK_FOR_SUITABLE_MODEL_SECONDS); 214 215 RandomModelExecutionResult executionResult = 216 RandomModelExecutionResult.fromNativeResult(nativeExecutionResult); 217 218 if (executionResult != RandomModelExecutionResult.kSuccess) { 219 Log.w(TAG, String.format( 220 "Received failure result '%s' at iteration %d, failing", 221 executionResult.mDescription, i)); 222 if (executionResult == RandomModelExecutionResult.kFailedExecution || 223 executionResult == RandomModelExecutionResult.kFailedCompilation) { 224 Log.i(TAG, String.format("Model has been dumped at path '%s'", 225 getFailedModelDumpPath())); 226 } else if ( 227 executionResult == RandomModelExecutionResult.kUnsupportedModelGenerated 228 || executionResult 229 == RandomModelExecutionResult.kInvalidModelGenerated) { 230 Log.w(TAG, String.format( 231 "Unable to find a valid model for test '%s', returning success " 232 + "anyway", 233 mTestName)); 234 235 return success(); 236 } 237 238 return failure(executionResult.mDescription); 239 } else if (!TextUtils.isEmpty(getNnapiLogFilePath())) { 240 (new File(getNnapiLogFilePath())).delete(); 241 } 242 243 Thread.sleep(mPauseBetweenModelsMillis); 244 } 245 246 return success(); 247 } finally { 248 RandomGraphTest.destroyRandomGraphGenerator(generatorHandle); 249 } 250 } 251 } 252