/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
import android.text.TextUtils;
import android.util.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BenchmarkResult implements Parcelable {
    public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
    public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";

    private final static int TIME_FREQ_ARRAY_SIZE = 32;

    private float mTotalTimeSec;
    private float mSumOfMSEs;
    private float mMaxSingleError;
    private int mIterations;
    private float mTimeStdDeviation;
    private String mTestInfo;
    private int mNumberOfEvaluatorResults;
    private String[] mEvaluatorKeys = {};
    private float[] mEvaluatorResults = {};

    /** Type of backend used for inference */
    private String mBackendType;

    /** Start time offset for the inference frequency histogram */
    private float mTimeFreqStartSec;

    /** Bucket width (time step) for the inference frequency histogram */
    private float mTimeFreqStepSec;

    /**
     * Array of inference frequency counts.
     * Each entry contains the inference count for the time range:
     * [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (i+1)*mTimeFreqStepSec)
     */
    private float[] mTimeFreqSec = {};

    /** Size of the test set used for inference */
    private int mTestSetSize;

    /** List of validation errors */
    private String[] mValidationErrors = {};

    /** Error that prevents the benchmark from running, e.g. SDK version not supported. */
    private String mBenchmarkError;

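    /**
     * Creates a fully populated result. {@code timeVarianceSec} is stored as the time
     * standard deviation; callers such as {@link #fromInferenceResults} pass the standard
     * deviation of the per-inference compute times. Null evaluator key/result and
     * validation-error arrays are treated as empty.
     */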
    public BenchmarkResult(float totalTimeSec, int iterations, float timeVarianceSec,
            float sumOfMSEs, float maxSingleError, String testInfo,
            String[] evaluatorKeys, float[] evaluatorResults,
            float timeFreqStartSec, float timeFreqStepSec, float[] timeFreqSec,
            String backendType, int testSetSize, String[] validationErrors) {
        mTotalTimeSec = totalTimeSec;
        mSumOfMSEs = sumOfMSEs;
        mMaxSingleError = maxSingleError;
        mIterations = iterations;
        mTimeStdDeviation = timeVarianceSec;
        mTestInfo = testInfo;
        mTimeFreqStartSec = timeFreqStartSec;
        mTimeFreqStepSec = timeFreqStepSec;
        mTimeFreqSec = timeFreqSec;
        mBackendType = backendType;
        mTestSetSize = testSetSize;
        if (validationErrors == null) {
            mValidationErrors = new String[0];
        } else {
            mValidationErrors = validationErrors;
        }

        if (evaluatorKeys == null) {
            mEvaluatorKeys = new String[0];
        } else {
            mEvaluatorKeys = evaluatorKeys;
        }
        if (evaluatorResults == null) {
            mEvaluatorResults = new float[0];
        } else {
            mEvaluatorResults = evaluatorResults;
        }
        if (mEvaluatorResults.length != mEvaluatorKeys.length) {
            throw new IllegalArgumentException("Different number of evaluator keys vs values");
        }
        mNumberOfEvaluatorResults = mEvaluatorResults.length;
    }

    public BenchmarkResult(String benchmarkError) {
        mBenchmarkError = benchmarkError;
    }

    public boolean hasValidationErrors() {
        return mValidationErrors.length > 0;
    }

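    /** Reconstructs a result from a Parcel; the read order must match {@link #writeToParcel}. */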
    protected BenchmarkResult(Parcel in) {
        mTotalTimeSec = in.readFloat();
        mSumOfMSEs = in.readFloat();
        mMaxSingleError = in.readFloat();
        mIterations = in.readInt();
        mTimeStdDeviation = in.readFloat();
        mTestInfo = in.readString();
        mNumberOfEvaluatorResults = in.readInt();
        mEvaluatorKeys = new String[mNumberOfEvaluatorResults];
        in.readStringArray(mEvaluatorKeys);
        mEvaluatorResults = new float[mNumberOfEvaluatorResults];
        in.readFloatArray(mEvaluatorResults);
        if (mEvaluatorResults.length != mEvaluatorKeys.length) {
            throw new IllegalArgumentException("Different number of evaluator keys vs values");
        }
        mTimeFreqStartSec = in.readFloat();
        mTimeFreqStepSec = in.readFloat();
        int timeFreqSecLength = in.readInt();
        mTimeFreqSec = new float[timeFreqSecLength];
        in.readFloatArray(mTimeFreqSec);
        mBackendType = in.readString();
        mTestSetSize = in.readInt();
        int validationsErrorsSize = in.readInt();
        mValidationErrors = new String[validationsErrorsSize];
        in.readStringArray(mValidationErrors);
        mBenchmarkError = in.readString();
    }

    @Override
    public int describeContents() {
        return 0;
    }

    @Override
    public void writeToParcel(Parcel dest, int flags) {
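        // Keep this write order in sync with the BenchmarkResult(Parcel) constructor.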
        dest.writeFloat(mTotalTimeSec);
        dest.writeFloat(mSumOfMSEs);
        dest.writeFloat(mMaxSingleError);
        dest.writeInt(mIterations);
        dest.writeFloat(mTimeStdDeviation);
        dest.writeString(mTestInfo);
        dest.writeInt(mNumberOfEvaluatorResults);
        dest.writeStringArray(mEvaluatorKeys);
        dest.writeFloatArray(mEvaluatorResults);
        dest.writeFloat(mTimeFreqStartSec);
        dest.writeFloat(mTimeFreqStepSec);
        dest.writeInt(mTimeFreqSec.length);
        dest.writeFloatArray(mTimeFreqSec);
        dest.writeString(mBackendType);
        dest.writeInt(mTestSetSize);
        dest.writeInt(mValidationErrors.length);
        dest.writeStringArray(mValidationErrors);
        dest.writeString(mBenchmarkError);
    }

    @SuppressWarnings("unused")
    public static final Parcelable.Creator<BenchmarkResult> CREATOR =
            new Parcelable.Creator<BenchmarkResult>() {
                @Override
                public BenchmarkResult createFromParcel(Parcel in) {
                    return new BenchmarkResult(in);
                }

                @Override
                public BenchmarkResult[] newArray(int size) {
                    return new BenchmarkResult[size];
                }
            };

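    /** Returns the sum of mean squared errors accumulated over all inference results. */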
    public float getError() {
        return mSumOfMSEs;
    }

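    /** Returns the mean inference time in seconds (total time divided by iteration count). */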
    public float getMeanTimeSec() {
        return mTotalTimeSec / mIterations;
    }

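    /** Returns evaluator results as (key, value) pairs, in the order reported by the evaluator. */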
    public List<Pair<String, Float>> getEvaluatorResults() {
        List<Pair<String, Float>> results = new ArrayList<>();
        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i]));
        }
        return results;
    }

    @Override
    public String toString() {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return mBenchmarkError;
        }

        StringBuilder result = new StringBuilder("BenchmarkResult{" +
                "mTestInfo='" + mTestInfo + '\'' +
                ", getMeanTimeSec()=" + getMeanTimeSec() +
                ", mTotalTimeSec=" + mTotalTimeSec +
                ", mSumOfMSEs=" + mSumOfMSEs +
                ", mMaxSingleError=" + mMaxSingleError +
                ", mIterations=" + mIterations +
                ", mTimeStdDeviation=" + mTimeStdDeviation +
                ", mTimeFreqStartSec=" + mTimeFreqStartSec +
                ", mTimeFreqStepSec=" + mTimeFreqStepSec +
                ", mBackendType=" + mBackendType +
                ", mTestSetSize=" + mTestSetSize);
        for (int i = 0; i < mEvaluatorKeys.length; i++) {
            result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]);
        }

        result.append(", mValidationErrors=[");
        for (int i = 0; i < mValidationErrors.length; i++) {
            result.append(mValidationErrors[i]);
            if (i < mValidationErrors.length - 1) {
                result.append(",");
            }
        }
        result.append("]");
        result.append('}');
        return result.toString();
    }

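    /**
     * Returns a one-line summary: the speedup relative to {@code baselineSec}, the iteration
     * count, and the mean and standard deviation of the inference time in milliseconds.
     */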
    public String getSummary(float baselineSec) {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return mBenchmarkError;
        }

        java.text.DecimalFormat df = new java.text.DecimalFormat("######.##");

        return df.format(rebase(getMeanTimeSec(), baselineSec)) +
            "X, n=" + mIterations + ", μ=" + df.format(getMeanTimeSec() * 1000.0)
            + "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms";
    }

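    /**
     * Packs the results into a Bundle keyed by {@code testName}. If the benchmark failed, only
     * the {@code _error} entry is set; otherwise timing values are reported in milliseconds.
     */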
    public Bundle toBundle(String testName) {
        Bundle results = new Bundle();
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            results.putString(testName + "_error", mBenchmarkError);
            return results;
        }

        // Timing values are reported in ms.
        results.putFloat(testName + "_avg", getMeanTimeSec() * 1000.0f);
        results.putFloat(testName + "_std_dev", mTimeStdDeviation * 1000.0f);
        results.putFloat(testName + "_total_time", mTotalTimeSec * 1000.0f);
        results.putFloat(testName + "_mean_square_error", mSumOfMSEs / mIterations);
        results.putFloat(testName + "_max_single_error", mMaxSingleError);
        results.putInt(testName + "_iterations", mIterations);
        for (int i = 0; i < mEvaluatorKeys.length; i++) {
            results.putFloat(testName + "_" + mEvaluatorKeys[i],
                mEvaluatorResults[i]);
        }
        return results;
    }

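    /**
     * Renders the result as a single CSV line: fixed columns (test info, backend, iterations,
     * total time, max single error, test set size, histogram start and step, and the lengths of
     * the three variable sections) followed by the evaluator keys, evaluator results, histogram
     * buckets and validation errors. Returns an empty string if the benchmark failed.
     */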
    @SuppressWarnings("AndroidJdkLibsChecker")
    public String toCsvLine() {
        if (!TextUtils.isEmpty(mBenchmarkError)) {
            return "";
        }

        StringBuilder sb = new StringBuilder();
        sb.append(String.join(",",
            mTestInfo,
            mBackendType,
            String.valueOf(mIterations),
            String.valueOf(mTotalTimeSec),
            String.valueOf(mMaxSingleError),
            String.valueOf(mTestSetSize),
            String.valueOf(mTimeFreqStartSec),
            String.valueOf(mTimeFreqStepSec),
            String.valueOf(mEvaluatorKeys.length),
            String.valueOf(mTimeFreqSec.length),
            String.valueOf(mValidationErrors.length)));

        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            sb.append(',').append(mEvaluatorKeys[i]);
        }

        for (int i = 0; i < mEvaluatorKeys.length; ++i) {
            sb.append(',').append(mEvaluatorResults[i]);
        }

        for (float value : mTimeFreqSec) {
            sb.append(',').append(value);
        }

        for (String validationError : mValidationErrors) {
            sb.append(',').append(validationError.replace(',', ' '));
        }

        sb.append('\n');
        return sb.toString();
    }

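    /**
     * Rebases a mean time against a baseline: returns {@code baselineSec / v} (a speedup factor)
     * for values above 1 ms, and returns {@code v} unchanged otherwise to avoid dividing by
     * values close to zero.
     */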
    float rebase(float v, float baselineSec) {
        if (v > 0.001) {
            v = baselineSec / v;
        }
        return v;
    }

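    /**
     * Aggregates per-inference results into a single BenchmarkResult: sums compute times and
     * mean squared errors, tracks the worst single error, computes the timing standard deviation
     * and a histogram of compute times, and runs the optional accuracy evaluator.
     */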
    public static BenchmarkResult fromInferenceResults(
            String testInfo,
            String backendType,
            List<InferenceInOutSequence> inferenceInOuts,
            List<InferenceResult> inferenceResults,
            EvaluatorInterface evaluator) {
        float totalTime = 0;
        int iterations = 0;
        float sumOfMSEs = 0;
        float maxSingleError = 0;

        float maxComputeTimeSec = 0.0f;
        float minComputeTimeSec = Float.MAX_VALUE;

        for (InferenceResult iresult : inferenceResults) {
            iterations++;
            totalTime += iresult.mComputeTimeSec;
            if (iresult.mMeanSquaredErrors != null) {
                for (float mse : iresult.mMeanSquaredErrors) {
                    sumOfMSEs += mse;
                }
            }
            if (iresult.mMaxSingleErrors != null) {
                for (float mse : iresult.mMaxSingleErrors) {
                    if (mse > maxSingleError) {
                        maxSingleError = mse;
                    }
                }
            }

            if (maxComputeTimeSec < iresult.mComputeTimeSec) {
                maxComputeTimeSec = iresult.mComputeTimeSec;
            }
            if (minComputeTimeSec > iresult.mComputeTimeSec) {
                minComputeTimeSec = iresult.mComputeTimeSec;
            }
        }

        float inferenceMean = (totalTime / iterations);

        float variance = 0.0f;
        for (InferenceResult iresult : inferenceResults) {
            float v = (iresult.mComputeTimeSec - inferenceMean);
            variance += v * v;
        }
        variance /= iterations;
        String[] evaluatorKeys = null;
        float[] evaluatorResults = null;
        String[] validationErrors = null;
        if (evaluator != null) {
            ArrayList<String> keys = new ArrayList<String>();
            ArrayList<Float> results = new ArrayList<Float>();
            ArrayList<String> validationErrorsList = new ArrayList<>();
            evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results,
                    validationErrorsList);
            evaluatorKeys = new String[keys.size()];
            evaluatorKeys = keys.toArray(evaluatorKeys);
            evaluatorResults = new float[results.size()];
            for (int i = 0; i < evaluatorResults.length; i++) {
                evaluatorResults[i] = results.get(i).floatValue();
            }
            validationErrors = new String[validationErrorsList.size()];
            validationErrorsList.toArray(validationErrors);
        }

        // Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets.
        float[] timeFreqSec = new float[TIME_FREQ_ARRAY_SIZE];
        float stepSize = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1);
        for (InferenceResult iresult : inferenceResults) {
            // If all compute times are equal, stepSize is 0 and the NaN index below casts to 0,
            // so every inference lands in the first bucket.
            timeFreqSec[(int) ((iresult.mComputeTimeSec - minComputeTimeSec) / stepSize)] += 1;
        }

        // Calculate the test set size.
        int testSetSize = 0;
        for (InferenceInOutSequence iios : inferenceInOuts) {
            testSetSize += iios.size();
        }

        return new BenchmarkResult(totalTime, iterations, (float) Math.sqrt(variance),
                sumOfMSEs, maxSingleError, testInfo, evaluatorKeys, evaluatorResults,
                minComputeTimeSec, stepSize, timeFreqSec, backendType, testSetSize,
                validationErrors);
    }
}