/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
import android.text.TextUtils;
import android.util.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BenchmarkResult implements Parcelable {
    public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
    public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";

    private final static int TIME_FREQ_ARRAY_SIZE = 32;

    private float mTotalTimeSec;
    private float mSumOfMSEs;
    private float mMaxSingleError;
    private int mIterations;
    private float mTimeStdDeviation;
    private String mTestInfo;
    private int mNumberOfEvaluatorResults;
    private String[] mEvaluatorKeys = {};
    private float[] mEvaluatorResults = {};

    /** Type of backend used for inference */
    private String mBackendType;

    /** Time offset for inference frequency counts */
    private float mTimeFreqStartSec;

    /** Time step (bucket width) for inference frequency counts */
    private float mTimeFreqStepSec;

    /**
     * Array of inference frequency counts.
     * Each entry contains the inference count for the time range:
     * [mTimeFreqStartSec + i * mTimeFreqStepSec, mTimeFreqStartSec + (i + 1) * mTimeFreqStepSec)
     */
    private float[] mTimeFreqSec = {};

    /** Size of the test set used for inference */
    private int mTestSetSize;

    /** List of validation errors */
    private String[] mValidationErrors = {};

    /** Error that prevents the benchmark from running, e.g. SDK version not supported. */
    private String mBenchmarkError;
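
    /**
     * Builds a result from already-aggregated values. The timeVarianceSec argument is stored
     * as the time standard deviation (callers such as fromInferenceResults() pass the square
     * root of the variance); null evaluator and validation arrays are normalized to empty
     * arrays.
     */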
    public BenchmarkResult(float totalTimeSec, int iterations, float timeVarianceSec,
            float sumOfMSEs, float maxSingleError, String testInfo,
            String[] evaluatorKeys, float[] evaluatorResults,
            float timeFreqStartSec, float timeFreqStepSec, float[] timeFreqSec,
            String backendType, int testSetSize, String[] validationErrors) {
        mTotalTimeSec = totalTimeSec;
        mSumOfMSEs = sumOfMSEs;
        mMaxSingleError = maxSingleError;
        mIterations = iterations;
        mTimeStdDeviation = timeVarianceSec;
        mTestInfo = testInfo;
        mTimeFreqStartSec = timeFreqStartSec;
        mTimeFreqStepSec = timeFreqStepSec;
        mTimeFreqSec = timeFreqSec;
        mBackendType = backendType;
        mTestSetSize = testSetSize;
        if (validationErrors == null) {
            mValidationErrors = new String[0];
        } else {
            mValidationErrors = validationErrors;
        }

        if (evaluatorKeys == null) {
            mEvaluatorKeys = new String[0];
        } else {
            mEvaluatorKeys = evaluatorKeys;
        }
        if (evaluatorResults == null) {
            mEvaluatorResults = new float[0];
        } else {
            mEvaluatorResults = evaluatorResults;
        }
        if (mEvaluatorResults.length != mEvaluatorKeys.length) {
            throw new IllegalArgumentException("Different number of evaluator keys vs values");
        }
        mNumberOfEvaluatorResults = mEvaluatorResults.length;
    }

    public BenchmarkResult(String benchmarkError) {
        mBenchmarkError = benchmarkError;
    }

    public boolean hasValidationErrors() {
        return mValidationErrors.length > 0;
    }
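
    /**
     * Recreates a result from a Parcel. The read order here must match the write order in
     * {@link #writeToParcel(Parcel, int)}.
     */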
    protected BenchmarkResult(Parcel in) {
        mTotalTimeSec = in.readFloat();
        mSumOfMSEs = in.readFloat();
        mMaxSingleError = in.readFloat();
        mIterations = in.readInt();
        mTimeStdDeviation = in.readFloat();
        mTestInfo = in.readString();
        mNumberOfEvaluatorResults = in.readInt();
        mEvaluatorKeys = new String[mNumberOfEvaluatorResults];
        in.readStringArray(mEvaluatorKeys);
        mEvaluatorResults = new float[mNumberOfEvaluatorResults];
        in.readFloatArray(mEvaluatorResults);
        if (mEvaluatorResults.length != mEvaluatorKeys.length) {
            throw new IllegalArgumentException("Different number of evaluator keys vs values");
        }
        mTimeFreqStartSec = in.readFloat();
        mTimeFreqStepSec = in.readFloat();
        int timeFreqSecLength = in.readInt();
        mTimeFreqSec = new float[timeFreqSecLength];
        in.readFloatArray(mTimeFreqSec);
        mBackendType = in.readString();
        mTestSetSize = in.readInt();
        int validationsErrorsSize = in.readInt();
        mValidationErrors = new String[validationsErrorsSize];
        in.readStringArray(mValidationErrors);
        mBenchmarkError = in.readString();
    }

    @Override
    public int describeContents() {
        return 0;
    }

    @Override
    public void writeToParcel(Parcel dest, int flags) {
        dest.writeFloat(mTotalTimeSec);
        dest.writeFloat(mSumOfMSEs);
        dest.writeFloat(mMaxSingleError);
        dest.writeInt(mIterations);
        dest.writeFloat(mTimeStdDeviation);
        dest.writeString(mTestInfo);
        dest.writeInt(mNumberOfEvaluatorResults);
        dest.writeStringArray(mEvaluatorKeys);
        dest.writeFloatArray(mEvaluatorResults);
        dest.writeFloat(mTimeFreqStartSec);
        dest.writeFloat(mTimeFreqStepSec);
        dest.writeInt(mTimeFreqSec.length);
        dest.writeFloatArray(mTimeFreqSec);
        dest.writeString(mBackendType);
        dest.writeInt(mTestSetSize);
        dest.writeInt(mValidationErrors.length);
        dest.writeStringArray(mValidationErrors);
        dest.writeString(mBenchmarkError);
    }

    @SuppressWarnings("unused")
    public static final Parcelable.Creator<BenchmarkResult> CREATOR =
            new Parcelable.Creator<BenchmarkResult>() {
                @Override
                public BenchmarkResult createFromParcel(Parcel in) {
                    return new BenchmarkResult(in);
                }

                @Override
                public BenchmarkResult[] newArray(int size) {
                    return new BenchmarkResult[size];
                }
            };

    public float getError() {
        return mSumOfMSEs;
    }

    public float getMeanTimeSec() {
        return mTotalTimeSec / mIterations;
    }

    public List<Pair<String, Float>> getEvaluatorResults() {
        List<Pair<String, Float>> results = new ArrayList<>();
        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i]));
        }
        return results;
    }

    @Override
    public String toString() {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return mBenchmarkError;
        }

        StringBuilder result = new StringBuilder("BenchmarkResult{" +
                "mTestInfo='" + mTestInfo + '\'' +
                ", getMeanTimeSec()=" + getMeanTimeSec() +
                ", mTotalTimeSec=" + mTotalTimeSec +
                ", mSumOfMSEs=" + mSumOfMSEs +
                ", mMaxSingleErrors=" + mMaxSingleError +
                ", mIterations=" + mIterations +
                ", mTimeStdDeviation=" + mTimeStdDeviation +
                ", mTimeFreqStartSec=" + mTimeFreqStartSec +
                ", mTimeFreqStepSec=" + mTimeFreqStepSec +
                ", mBackendType=" + mBackendType +
                ", mTestSetSize=" + mTestSetSize);
        for (int i = 0; i < mEvaluatorKeys.length; i++) {
            result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]);
        }

        result.append(", mValidationErrors=[");
        for (int i = 0; i < mValidationErrors.length; i++) {
            result.append(mValidationErrors[i]);
            if (i < mValidationErrors.length - 1) {
                result.append(",");
            }
        }
        result.append("]");
        result.append('}');
        return result.toString();
    }

    public String getSummary(float baselineSec) {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return mBenchmarkError;
        }

        java.text.DecimalFormat df = new java.text.DecimalFormat("######.##");

        return df.format(rebase(getMeanTimeSec(), baselineSec)) +
                "X, n=" + mIterations + ", μ=" + df.format(getMeanTimeSec() * 1000.0)
                + "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms";
    }
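
    /**
     * Packs the result into a Bundle of key/value pairs prefixed with the test name.
     * Time-based values are reported in milliseconds.
     */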
    public Bundle toBundle(String testName) {
        Bundle results = new Bundle();
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            results.putString(testName + "_error", mBenchmarkError);
            return results;
        }

        // Reported in ms
        results.putFloat(testName + "_avg", getMeanTimeSec() * 1000.0f);
        results.putFloat(testName + "_std_dev", mTimeStdDeviation * 1000.0f);
        results.putFloat(testName + "_total_time", mTotalTimeSec * 1000.0f);
        results.putFloat(testName + "_mean_square_error", mSumOfMSEs / mIterations);
        results.putFloat(testName + "_max_single_error", mMaxSingleError);
        results.putInt(testName + "_iterations", mIterations);
        for (int i = 0; i < mEvaluatorKeys.length; i++) {
            results.putFloat(testName + "_" + mEvaluatorKeys[i],
                    mEvaluatorResults[i]);
        }
        return results;
    }

    /**
     * Serializes the result as a single CSV line (empty if the benchmark could not run).
     * Commas inside validation error messages are replaced with spaces.
     */
    @SuppressWarnings("AndroidJdkLibsChecker")
    public String toCsvLine() {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return "";
        }

        StringBuilder sb = new StringBuilder();
        sb.append(String.join(",",
                mTestInfo,
                mBackendType,
                String.valueOf(mIterations),
                String.valueOf(mTotalTimeSec),
                String.valueOf(mMaxSingleError),
                String.valueOf(mTestSetSize),
                String.valueOf(mTimeFreqStartSec),
                String.valueOf(mTimeFreqStepSec),
                String.valueOf(mEvaluatorKeys.length),
                String.valueOf(mTimeFreqSec.length),
                String.valueOf(mValidationErrors.length)));

        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            sb.append(',').append(mEvaluatorKeys[i]);
        }

        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            sb.append(',').append(mEvaluatorResults[i]);
        }

        for (float value : mTimeFreqSec) {
            sb.append(',').append(value);
        }

        for (String validationError : mValidationErrors) {
            sb.append(',').append(validationError.replace(',', ' '));
        }

        sb.append('\n');
        return sb.toString();
    }

    /** Expresses v (in seconds) as a speed-up factor relative to baselineSec; values of 1 ms or less are returned unchanged. */
    float rebase(float v, float baselineSec) {
        if (v > 0.001) {
            v = baselineSec / v;
        }
        return v;
    }
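
    /**
     * Aggregates raw per-inference results into a single BenchmarkResult: total and mean
     * compute time, time standard deviation, a TIME_FREQ_ARRAY_SIZE-bucket latency histogram,
     * and (when an evaluator is supplied) evaluator metrics and validation errors.
     */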
    public static BenchmarkResult fromInferenceResults(
            String testInfo,
            String backendType,
            List<InferenceInOutSequence> inferenceInOuts,
            List<InferenceResult> inferenceResults,
            EvaluatorInterface evaluator) {
        float totalTime = 0;
        int iterations = 0;
        float sumOfMSEs = 0;
        float maxSingleError = 0;

        float maxComputeTimeSec = 0.0f;
        float minComputeTimeSec = Float.MAX_VALUE;

        for (InferenceResult iresult : inferenceResults) {
            iterations++;
            totalTime += iresult.mComputeTimeSec;
            if (iresult.mMeanSquaredErrors != null) {
                for (float mse : iresult.mMeanSquaredErrors) {
                    sumOfMSEs += mse;
                }
            }
            if (iresult.mMaxSingleErrors != null) {
                for (float mse : iresult.mMaxSingleErrors) {
                    if (mse > maxSingleError) {
                        maxSingleError = mse;
                    }
                }
            }

            if (maxComputeTimeSec < iresult.mComputeTimeSec) {
                maxComputeTimeSec = iresult.mComputeTimeSec;
            }
            if (minComputeTimeSec > iresult.mComputeTimeSec) {
                minComputeTimeSec = iresult.mComputeTimeSec;
            }
        }

        float inferenceMean = (totalTime / iterations);

        // Population variance of the per-inference compute times; its square root is passed
        // to the constructor as the standard deviation.
        float variance = 0.0f;
        for (InferenceResult iresult : inferenceResults) {
            float v = (iresult.mComputeTimeSec - inferenceMean);
            variance += v * v;
        }
        variance /= iterations;

        String[] evaluatorKeys = null;
        float[] evaluatorResults = null;
        String[] validationErrors = null;
        if (evaluator != null) {
            ArrayList<String> keys = new ArrayList<String>();
            ArrayList<Float> results = new ArrayList<Float>();
            ArrayList<String> validationErrorsList = new ArrayList<>();
            evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results,
                    validationErrorsList);
            evaluatorKeys = new String[keys.size()];
            evaluatorKeys = keys.toArray(evaluatorKeys);
            evaluatorResults = new float[results.size()];
            for (int i = 0; i < evaluatorResults.length; i++) {
                evaluatorResults[i] = results.get(i).floatValue();
            }
            validationErrors = new String[validationErrorsList.size()];
            validationErrorsList.toArray(validationErrors);
        }

        // Calculate the inference latency histogram across TIME_FREQ_ARRAY_SIZE buckets.
        // Bucket i counts inferences whose compute time falls in
        // [minComputeTimeSec + i * stepSize, minComputeTimeSec + (i + 1) * stepSize).
        float[] timeFreqSec = new float[TIME_FREQ_ARRAY_SIZE];
        float stepSize = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1);
        for (InferenceResult iresult : inferenceResults) {
            timeFreqSec[(int) ((iresult.mComputeTimeSec - minComputeTimeSec) / stepSize)] += 1;
        }

        // Total test set size across all input/output sequences.
        int testSetSize = 0;
        for (InferenceInOutSequence iios : inferenceInOuts) {
            testSetSize += iios.size();
        }

        return new BenchmarkResult(totalTime, iterations, (float) Math.sqrt(variance),
                sumOfMSEs, maxSingleError, testInfo, evaluatorKeys, evaluatorResults,
                minComputeTimeSec, stepSize, timeFreqSec, backendType, testSetSize,
                validationErrors);
    }
}
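
// Example usage (a minimal sketch): only the BenchmarkResult calls below are real; the test
// name, baseline value, and the inferenceInOuts/inferenceResults/evaluator variables are
// assumed to be provided by the surrounding benchmark harness.
//
//   BenchmarkResult result = BenchmarkResult.fromInferenceResults(
//           "mobilenet_float", BenchmarkResult.BACKEND_TFLITE_NNAPI,
//           inferenceInOuts, inferenceResults, evaluator);
//   String summary = result.getSummary(/* baselineSec= */ 0.01f);
//   Bundle status = result.toBundle("mobilenet_float");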