1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.os.Bundle; 20 import android.os.Parcel; 21 import android.os.Parcelable; 22 import android.text.TextUtils; 23 import android.util.Pair; 24 25 import java.util.ArrayList; 26 import java.util.Arrays; 27 import java.util.List; 28 29 public class BenchmarkResult implements Parcelable { 30 // Used by CTS tests. 31 public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI"; 32 public final static String BACKEND_TFLITE_CPU = "TFLite_CPU"; 33 34 private final static int TIME_FREQ_ARRAY_SIZE = 32; 35 36 /** The name of the benchmark */ 37 private String mTestInfo; 38 39 /** Latency results */ 40 private LatencyResult mLatencyInference; 41 private LatencyResult mLatencyCompileWithoutCache; 42 private LatencyResult mLatencySaveToCache; 43 private LatencyResult mLatencyPrepareFromCache; 44 45 /** Accuracy results */ 46 private float mSumOfMSEs; 47 private float mMaxSingleError; 48 private int mNumberOfEvaluatorResults; 49 private String[] mEvaluatorKeys = {}; 50 private float[] mEvaluatorResults = {}; 51 52 /** Type of backend used for inference */ 53 private String mBackendType; 54 55 /** Size of test set using for inference */ 56 private int mTestSetSize; 57 58 /** Size of compilation cache files in bytes */ 59 private int mCompilationCacheSizeBytes = 0; 60 61 /** List of validation errors */ 62 private String[] mValidationErrors = {}; 63 64 /** Error that prevents the benchmark from running, e.g. SDK version not supported. */ 65 private String mBenchmarkError; 66 BenchmarkResult(LatencyResult inferenceLatency, float sumOfMSEs, float maxSingleError, String testInfo, String[] evaluatorKeys, float[] evaluatorResults, String backendType, int testSetSize, String[] validationErrors)67 public BenchmarkResult(LatencyResult inferenceLatency, 68 float sumOfMSEs, float maxSingleError, String testInfo, 69 String[] evaluatorKeys, float[] evaluatorResults, 70 String backendType, int testSetSize, String[] validationErrors) { 71 mLatencyInference = inferenceLatency; 72 mSumOfMSEs = sumOfMSEs; 73 mMaxSingleError = maxSingleError; 74 mTestInfo = testInfo; 75 mBackendType = backendType; 76 mTestSetSize = testSetSize; 77 if (validationErrors == null) { 78 mValidationErrors = new String[0]; 79 } else { 80 mValidationErrors = validationErrors; 81 } 82 83 if (evaluatorKeys == null) { 84 mEvaluatorKeys = new String[0]; 85 } else { 86 mEvaluatorKeys = evaluatorKeys; 87 } 88 if (evaluatorResults == null) { 89 mEvaluatorResults = new float[0]; 90 } else { 91 mEvaluatorResults = evaluatorResults; 92 } 93 if (mEvaluatorResults.length != mEvaluatorKeys.length) { 94 throw new IllegalArgumentException("Different number of evaluator keys vs values"); 95 } 96 mNumberOfEvaluatorResults = mEvaluatorResults.length; 97 } 98 BenchmarkResult(String benchmarkError)99 public BenchmarkResult(String benchmarkError) { 100 mBenchmarkError = benchmarkError; 101 } 102 hasValidationErrors()103 public boolean hasValidationErrors() { 104 return mValidationErrors.length > 0; 105 } 106 BenchmarkResult(Parcel in)107 protected BenchmarkResult(Parcel in) { 108 mLatencyInference = in.readParcelable(LatencyResult.class.getClassLoader()); 109 mLatencyCompileWithoutCache = in.readParcelable(LatencyResult.class.getClassLoader()); 110 mLatencySaveToCache = in.readParcelable(LatencyResult.class.getClassLoader()); 111 mLatencyPrepareFromCache = in.readParcelable(LatencyResult.class.getClassLoader()); 112 mSumOfMSEs = in.readFloat(); 113 mMaxSingleError = in.readFloat(); 114 mTestInfo = in.readString(); 115 mNumberOfEvaluatorResults = in.readInt(); 116 mEvaluatorKeys = new String[mNumberOfEvaluatorResults]; 117 in.readStringArray(mEvaluatorKeys); 118 mEvaluatorResults = new float[mNumberOfEvaluatorResults]; 119 in.readFloatArray(mEvaluatorResults); 120 if (mEvaluatorResults.length != mEvaluatorKeys.length) { 121 throw new IllegalArgumentException("Different number of evaluator keys vs values"); 122 } 123 mBackendType = in.readString(); 124 mTestSetSize = in.readInt(); 125 mCompilationCacheSizeBytes = in.readInt(); 126 int validationsErrorsSize = in.readInt(); 127 mValidationErrors = new String[validationsErrorsSize]; 128 in.readStringArray(mValidationErrors); 129 mBenchmarkError = in.readString(); 130 } 131 132 @Override describeContents()133 public int describeContents() { 134 return 0; 135 } 136 137 @Override writeToParcel(Parcel dest, int flags)138 public void writeToParcel(Parcel dest, int flags) { 139 dest.writeParcelable(mLatencyInference, flags); 140 dest.writeParcelable(mLatencyCompileWithoutCache, flags); 141 dest.writeParcelable(mLatencySaveToCache, flags); 142 dest.writeParcelable(mLatencyPrepareFromCache, flags); 143 dest.writeFloat(mSumOfMSEs); 144 dest.writeFloat(mMaxSingleError); 145 dest.writeString(mTestInfo); 146 dest.writeInt(mNumberOfEvaluatorResults); 147 dest.writeStringArray(mEvaluatorKeys); 148 dest.writeFloatArray(mEvaluatorResults); 149 dest.writeString(mBackendType); 150 dest.writeInt(mTestSetSize); 151 dest.writeInt(mCompilationCacheSizeBytes); 152 dest.writeInt(mValidationErrors.length); 153 dest.writeStringArray(mValidationErrors); 154 dest.writeString(mBenchmarkError); 155 } 156 157 @SuppressWarnings("unused") 158 public static final Parcelable.Creator<BenchmarkResult> CREATOR = 159 new Parcelable.Creator<BenchmarkResult>() { 160 @Override 161 public BenchmarkResult createFromParcel(Parcel in) { 162 return new BenchmarkResult(in); 163 } 164 165 @Override 166 public BenchmarkResult[] newArray(int size) { 167 return new BenchmarkResult[size]; 168 } 169 }; 170 getError()171 public float getError() { 172 return mSumOfMSEs; 173 } 174 getMeanTimeSec()175 public float getMeanTimeSec() { 176 return mLatencyInference.getMeanTimeSec(); 177 } 178 getCompileWithoutCacheMeanTimeSec()179 public float getCompileWithoutCacheMeanTimeSec() { 180 return mLatencyCompileWithoutCache == null ? 0.0f 181 : mLatencyCompileWithoutCache.getMeanTimeSec(); 182 } 183 getSaveToCacheMeanTimeSec()184 public float getSaveToCacheMeanTimeSec() { 185 return mLatencySaveToCache == null ? 0.0f : mLatencySaveToCache.getMeanTimeSec(); 186 } 187 getPrepareFromCacheMeanTimeSec()188 public float getPrepareFromCacheMeanTimeSec() { 189 return mLatencyPrepareFromCache == null ? 0.0f : mLatencyPrepareFromCache.getMeanTimeSec(); 190 } 191 getEvaluatorResults()192 public List<Pair<String, Float>> getEvaluatorResults() { 193 List<Pair<String, Float>> results = new ArrayList<>(); 194 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 195 results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i])); 196 } 197 return results; 198 } 199 200 @Override toString()201 public String toString() { 202 if (!TextUtils.isEmpty(mBenchmarkError)) { 203 return mBenchmarkError; 204 } 205 206 StringBuilder result = new StringBuilder("BenchmarkResult{" + 207 "mTestInfo='" + mTestInfo + '\'' + 208 ", mLatencyInference=" + mLatencyInference.toString() + 209 ", mSumOfMSEs=" + mSumOfMSEs + 210 ", mMaxSingleErrors=" + mMaxSingleError + 211 ", mBackendType=" + mBackendType + 212 ", mTestSetSize=" + mTestSetSize); 213 for (int i = 0; i < mEvaluatorKeys.length; i++) { 214 result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]); 215 } 216 217 result.append(", mValidationErrors=["); 218 for (int i = 0; i < mValidationErrors.length; i++) { 219 result.append(mValidationErrors[i]); 220 if (i < mValidationErrors.length - 1) { 221 result.append(","); 222 } 223 } 224 result.append("]"); 225 226 if (mLatencyCompileWithoutCache != null) { 227 result.append(", mLatencyCompileWithoutCache=") 228 .append(mLatencyCompileWithoutCache.toString()); 229 } 230 if (mLatencySaveToCache != null) { 231 result.append(", mLatencySaveToCache=").append(mLatencySaveToCache.toString()); 232 } 233 if (mLatencyPrepareFromCache != null) { 234 result.append(", mLatencyPrepareFromCache=") 235 .append(mLatencyPrepareFromCache.toString()); 236 } 237 result.append(", mCompilationCacheSizeBytes=").append(mCompilationCacheSizeBytes); 238 239 result.append('}'); 240 return result.toString(); 241 } 242 hasBenchmarkError()243 public boolean hasBenchmarkError() { 244 return !TextUtils.isEmpty(mBenchmarkError); 245 } 246 getBenchmarkError()247 public String getBenchmarkError() { 248 if (!hasBenchmarkError()) return null; 249 250 return mBenchmarkError; 251 } 252 setBenchmarkError(String benchmarkError)253 public void setBenchmarkError(String benchmarkError) { 254 mBenchmarkError = benchmarkError; 255 } 256 getSummary(float baselineSec)257 public String getSummary(float baselineSec) { 258 if (hasBenchmarkError()) { 259 return getBenchmarkError(); 260 } 261 return mLatencyInference.getSummary(baselineSec); 262 } 263 toBundle(String testName)264 public Bundle toBundle(String testName) { 265 Bundle results = new Bundle(); 266 if (!TextUtils.isEmpty(mBenchmarkError)) { 267 results.putString(testName + "_error", mBenchmarkError); 268 return results; 269 } 270 271 mLatencyInference.putToBundle(results, testName + "_inference"); 272 results.putFloat(testName + "_inference_mean_square_error", 273 mSumOfMSEs / mLatencyInference.getIterations()); 274 results.putFloat(testName + "_inference_max_single_error", mMaxSingleError); 275 for (int i = 0; i < mEvaluatorKeys.length; i++) { 276 results.putFloat(testName + "_inference_" + mEvaluatorKeys[i], mEvaluatorResults[i]); 277 } 278 if (mLatencyCompileWithoutCache != null) { 279 mLatencyCompileWithoutCache.putToBundle(results, testName + "_compile_without_cache"); 280 } 281 if (mLatencySaveToCache != null) { 282 mLatencySaveToCache.putToBundle(results, testName + "_save_to_cache"); 283 } 284 if (mLatencyPrepareFromCache != null) { 285 mLatencyPrepareFromCache.putToBundle(results, testName + "_prepare_from_cache"); 286 } 287 if (mCompilationCacheSizeBytes > 0) { 288 results.putInt(testName + "_compilation_cache_size", mCompilationCacheSizeBytes); 289 } 290 return results; 291 } 292 293 @SuppressWarnings("AndroidJdkLibsChecker") toCsvLine()294 public String toCsvLine() { 295 if (!TextUtils.isEmpty(mBenchmarkError)) { 296 return ""; 297 } 298 299 StringBuilder sb = new StringBuilder(); 300 sb.append(mTestInfo).append(',').append(mBackendType); 301 302 mLatencyInference.appendToCsvLine(sb); 303 304 sb.append(',').append(String.join(",", 305 String.valueOf(mMaxSingleError), 306 String.valueOf(mTestSetSize), 307 String.valueOf(mEvaluatorKeys.length), 308 String.valueOf(mValidationErrors.length))); 309 310 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 311 sb.append(',').append(mEvaluatorKeys[i]); 312 } 313 314 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 315 sb.append(',').append(mEvaluatorResults[i]); 316 } 317 318 for (String validationError : mValidationErrors) { 319 sb.append(',').append(validationError.replace(',', ' ')); 320 } 321 322 sb.append(',').append(mLatencyCompileWithoutCache != null); 323 if (mLatencyCompileWithoutCache != null) { 324 mLatencyCompileWithoutCache.appendToCsvLine(sb); 325 } 326 sb.append(',').append(mLatencySaveToCache != null); 327 if (mLatencySaveToCache != null) { 328 mLatencySaveToCache.appendToCsvLine(sb); 329 } 330 sb.append(',').append(mLatencyPrepareFromCache != null); 331 if (mLatencyPrepareFromCache != null) { 332 mLatencyPrepareFromCache.appendToCsvLine(sb); 333 } 334 sb.append(',').append(mCompilationCacheSizeBytes); 335 336 sb.append('\n'); 337 return sb.toString(); 338 } 339 fromInferenceResults( String testInfo, String backendType, List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, EvaluatorInterface evaluator)340 public static BenchmarkResult fromInferenceResults( 341 String testInfo, 342 String backendType, 343 List<InferenceInOutSequence> inferenceInOuts, 344 List<InferenceResult> inferenceResults, 345 EvaluatorInterface evaluator) { 346 float[] latencies = new float[inferenceResults.size()]; 347 float sumOfMSEs = 0; 348 float maxSingleError = 0; 349 for (int i = 0; i < inferenceResults.size(); i++) { 350 InferenceResult iresult = inferenceResults.get(i); 351 latencies[i] = iresult.mComputeTimeSec; 352 if (iresult.mMeanSquaredErrors != null) { 353 for (float mse : iresult.mMeanSquaredErrors) { 354 sumOfMSEs += mse; 355 } 356 } 357 if (iresult.mMaxSingleErrors != null) { 358 for (float mse : iresult.mMaxSingleErrors) { 359 if (mse > maxSingleError) { 360 maxSingleError = mse; 361 } 362 } 363 } 364 } 365 366 String[] evaluatorKeys = null; 367 float[] evaluatorResults = null; 368 String[] validationErrors = null; 369 if (evaluator != null) { 370 ArrayList<String> keys = new ArrayList<String>(); 371 ArrayList<Float> results = new ArrayList<Float>(); 372 ArrayList<String> validationErrorsList = new ArrayList<>(); 373 evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results, 374 validationErrorsList); 375 evaluatorKeys = new String[keys.size()]; 376 evaluatorKeys = keys.toArray(evaluatorKeys); 377 evaluatorResults = new float[results.size()]; 378 for (int i = 0; i < evaluatorResults.length; i++) { 379 evaluatorResults[i] = results.get(i).floatValue(); 380 } 381 validationErrors = new String[validationErrorsList.size()]; 382 validationErrorsList.toArray(validationErrors); 383 } 384 385 // Calc test set size 386 int testSetSize = 0; 387 for (InferenceInOutSequence iios : inferenceInOuts) { 388 testSetSize += iios.size(); 389 } 390 391 return new BenchmarkResult(new LatencyResult(latencies), sumOfMSEs, maxSingleError, 392 testInfo, evaluatorKeys, evaluatorResults, backendType, testSetSize, 393 validationErrors); 394 } 395 setCompilationBenchmarkResult(CompilationBenchmarkResult result)396 public void setCompilationBenchmarkResult(CompilationBenchmarkResult result) { 397 mLatencyCompileWithoutCache = new LatencyResult(result.mCompileWithoutCacheTimeSec); 398 if (result.mSaveToCacheTimeSec != null) { 399 mLatencySaveToCache = new LatencyResult(result.mSaveToCacheTimeSec); 400 } 401 if (result.mPrepareFromCacheTimeSec != null) { 402 mLatencyPrepareFromCache = new LatencyResult(result.mPrepareFromCacheTimeSec); 403 } 404 mCompilationCacheSizeBytes = result.mCacheSizeBytes; 405 } 406 } 407