1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.os.Bundle;
20 import android.os.Parcel;
21 import android.os.Parcelable;
22 import android.text.TextUtils;
23 import android.util.Pair;
24 
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.List;
28 
29 public class BenchmarkResult implements Parcelable {
30     // Used by CTS tests.
31     public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
32     public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";
33 
34     private final static int TIME_FREQ_ARRAY_SIZE = 32;
35 
36     /** The name of the benchmark */
37     private String mTestInfo;
38 
39     /** Latency results */
40     private LatencyResult mLatencyInference;
41     private LatencyResult mLatencyCompileWithoutCache;
42     private LatencyResult mLatencySaveToCache;
43     private LatencyResult mLatencyPrepareFromCache;
44 
45     /** Accuracy results */
46     private float mSumOfMSEs;
47     private float mMaxSingleError;
48     private int mNumberOfEvaluatorResults;
49     private String[] mEvaluatorKeys = {};
50     private float[] mEvaluatorResults = {};
51 
52     /** Type of backend used for inference */
53     private String mBackendType;
54 
55     /** Size of test set using for inference */
56     private int mTestSetSize;
57 
58     /** Size of compilation cache files in bytes */
59     private int mCompilationCacheSizeBytes = 0;
60 
61     /** List of validation errors */
62     private String[] mValidationErrors = {};
63 
64     /** Error that prevents the benchmark from running, e.g. SDK version not supported. */
65     private String mBenchmarkError;
66 
BenchmarkResult(LatencyResult inferenceLatency, float sumOfMSEs, float maxSingleError, String testInfo, String[] evaluatorKeys, float[] evaluatorResults, String backendType, int testSetSize, String[] validationErrors)67     public BenchmarkResult(LatencyResult inferenceLatency,
68             float sumOfMSEs, float maxSingleError, String testInfo,
69             String[] evaluatorKeys, float[] evaluatorResults,
70             String backendType, int testSetSize, String[] validationErrors) {
71         mLatencyInference = inferenceLatency;
72         mSumOfMSEs = sumOfMSEs;
73         mMaxSingleError = maxSingleError;
74         mTestInfo = testInfo;
75         mBackendType = backendType;
76         mTestSetSize = testSetSize;
77         if (validationErrors == null) {
78             mValidationErrors = new String[0];
79         } else {
80             mValidationErrors = validationErrors;
81         }
82 
83         if (evaluatorKeys == null) {
84             mEvaluatorKeys = new String[0];
85         } else {
86             mEvaluatorKeys = evaluatorKeys;
87         }
88         if (evaluatorResults == null) {
89             mEvaluatorResults = new float[0];
90         } else {
91             mEvaluatorResults = evaluatorResults;
92         }
93         if (mEvaluatorResults.length != mEvaluatorKeys.length) {
94             throw new IllegalArgumentException("Different number of evaluator keys vs values");
95         }
96         mNumberOfEvaluatorResults = mEvaluatorResults.length;
97     }
98 
BenchmarkResult(String benchmarkError)99     public BenchmarkResult(String benchmarkError) {
100         mBenchmarkError = benchmarkError;
101     }
102 
hasValidationErrors()103     public boolean hasValidationErrors() {
104         return mValidationErrors.length > 0;
105     }
106 
BenchmarkResult(Parcel in)107     protected BenchmarkResult(Parcel in) {
108         mLatencyInference = in.readParcelable(LatencyResult.class.getClassLoader());
109         mLatencyCompileWithoutCache = in.readParcelable(LatencyResult.class.getClassLoader());
110         mLatencySaveToCache = in.readParcelable(LatencyResult.class.getClassLoader());
111         mLatencyPrepareFromCache = in.readParcelable(LatencyResult.class.getClassLoader());
112         mSumOfMSEs = in.readFloat();
113         mMaxSingleError = in.readFloat();
114         mTestInfo = in.readString();
115         mNumberOfEvaluatorResults = in.readInt();
116         mEvaluatorKeys = new String[mNumberOfEvaluatorResults];
117         in.readStringArray(mEvaluatorKeys);
118         mEvaluatorResults = new float[mNumberOfEvaluatorResults];
119         in.readFloatArray(mEvaluatorResults);
120         if (mEvaluatorResults.length != mEvaluatorKeys.length) {
121             throw new IllegalArgumentException("Different number of evaluator keys vs values");
122         }
123         mBackendType = in.readString();
124         mTestSetSize = in.readInt();
125         mCompilationCacheSizeBytes = in.readInt();
126         int validationsErrorsSize = in.readInt();
127         mValidationErrors = new String[validationsErrorsSize];
128         in.readStringArray(mValidationErrors);
129         mBenchmarkError = in.readString();
130     }
131 
132     @Override
describeContents()133     public int describeContents() {
134         return 0;
135     }
136 
137     @Override
writeToParcel(Parcel dest, int flags)138     public void writeToParcel(Parcel dest, int flags) {
139         dest.writeParcelable(mLatencyInference, flags);
140         dest.writeParcelable(mLatencyCompileWithoutCache, flags);
141         dest.writeParcelable(mLatencySaveToCache, flags);
142         dest.writeParcelable(mLatencyPrepareFromCache, flags);
143         dest.writeFloat(mSumOfMSEs);
144         dest.writeFloat(mMaxSingleError);
145         dest.writeString(mTestInfo);
146         dest.writeInt(mNumberOfEvaluatorResults);
147         dest.writeStringArray(mEvaluatorKeys);
148         dest.writeFloatArray(mEvaluatorResults);
149         dest.writeString(mBackendType);
150         dest.writeInt(mTestSetSize);
151         dest.writeInt(mCompilationCacheSizeBytes);
152         dest.writeInt(mValidationErrors.length);
153         dest.writeStringArray(mValidationErrors);
154         dest.writeString(mBenchmarkError);
155     }
156 
157     @SuppressWarnings("unused")
158     public static final Parcelable.Creator<BenchmarkResult> CREATOR =
159             new Parcelable.Creator<BenchmarkResult>() {
160                 @Override
161                 public BenchmarkResult createFromParcel(Parcel in) {
162                     return new BenchmarkResult(in);
163                 }
164 
165                 @Override
166                 public BenchmarkResult[] newArray(int size) {
167                     return new BenchmarkResult[size];
168                 }
169             };
170 
getError()171     public float getError() {
172         return mSumOfMSEs;
173     }
174 
getMeanTimeSec()175     public float getMeanTimeSec() {
176         return mLatencyInference.getMeanTimeSec();
177     }
178 
getCompileWithoutCacheMeanTimeSec()179     public float getCompileWithoutCacheMeanTimeSec() {
180         return mLatencyCompileWithoutCache == null ? 0.0f
181             : mLatencyCompileWithoutCache.getMeanTimeSec();
182     }
183 
getSaveToCacheMeanTimeSec()184     public float getSaveToCacheMeanTimeSec() {
185         return mLatencySaveToCache == null ? 0.0f : mLatencySaveToCache.getMeanTimeSec();
186     }
187 
getPrepareFromCacheMeanTimeSec()188     public float getPrepareFromCacheMeanTimeSec() {
189         return mLatencyPrepareFromCache == null ? 0.0f : mLatencyPrepareFromCache.getMeanTimeSec();
190     }
191 
getEvaluatorResults()192     public List<Pair<String, Float>> getEvaluatorResults() {
193         List<Pair<String, Float>> results = new ArrayList<>();
194         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
195             results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i]));
196         }
197         return results;
198     }
199 
200     @Override
toString()201     public String toString() {
202         if (!TextUtils.isEmpty(mBenchmarkError)) {
203             return mBenchmarkError;
204         }
205 
206         StringBuilder result = new StringBuilder("BenchmarkResult{" +
207                 "mTestInfo='" + mTestInfo + '\'' +
208                 ", mLatencyInference=" + mLatencyInference.toString() +
209                 ", mSumOfMSEs=" + mSumOfMSEs +
210                 ", mMaxSingleErrors=" + mMaxSingleError +
211                 ", mBackendType=" + mBackendType +
212                 ", mTestSetSize=" + mTestSetSize);
213         for (int i = 0; i < mEvaluatorKeys.length; i++) {
214             result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]);
215         }
216 
217         result.append(", mValidationErrors=[");
218         for (int i = 0; i < mValidationErrors.length; i++) {
219             result.append(mValidationErrors[i]);
220             if (i < mValidationErrors.length - 1) {
221                 result.append(",");
222             }
223         }
224         result.append("]");
225 
226         if (mLatencyCompileWithoutCache != null) {
227             result.append(", mLatencyCompileWithoutCache=")
228                     .append(mLatencyCompileWithoutCache.toString());
229         }
230         if (mLatencySaveToCache != null) {
231             result.append(", mLatencySaveToCache=").append(mLatencySaveToCache.toString());
232         }
233         if (mLatencyPrepareFromCache != null) {
234             result.append(", mLatencyPrepareFromCache=")
235                     .append(mLatencyPrepareFromCache.toString());
236         }
237         result.append(", mCompilationCacheSizeBytes=").append(mCompilationCacheSizeBytes);
238 
239         result.append('}');
240         return result.toString();
241     }
242 
hasBenchmarkError()243     public boolean hasBenchmarkError() {
244         return !TextUtils.isEmpty(mBenchmarkError);
245     }
246 
getBenchmarkError()247     public String getBenchmarkError() {
248         if (!hasBenchmarkError()) return null;
249 
250         return mBenchmarkError;
251     }
252 
setBenchmarkError(String benchmarkError)253     public void setBenchmarkError(String benchmarkError) {
254         mBenchmarkError = benchmarkError;
255     }
256 
getSummary(float baselineSec)257     public String getSummary(float baselineSec) {
258         if (hasBenchmarkError()) {
259             return getBenchmarkError();
260         }
261         return mLatencyInference.getSummary(baselineSec);
262     }
263 
toBundle(String testName)264     public Bundle toBundle(String testName) {
265         Bundle results = new Bundle();
266         if (!TextUtils.isEmpty(mBenchmarkError)) {
267             results.putString(testName + "_error", mBenchmarkError);
268             return results;
269         }
270 
271         mLatencyInference.putToBundle(results, testName + "_inference");
272         results.putFloat(testName + "_inference_mean_square_error",
273                 mSumOfMSEs / mLatencyInference.getIterations());
274         results.putFloat(testName + "_inference_max_single_error", mMaxSingleError);
275         for (int i = 0; i < mEvaluatorKeys.length; i++) {
276             results.putFloat(testName + "_inference_" + mEvaluatorKeys[i], mEvaluatorResults[i]);
277         }
278         if (mLatencyCompileWithoutCache != null) {
279             mLatencyCompileWithoutCache.putToBundle(results, testName + "_compile_without_cache");
280         }
281         if (mLatencySaveToCache != null) {
282             mLatencySaveToCache.putToBundle(results, testName + "_save_to_cache");
283         }
284         if (mLatencyPrepareFromCache != null) {
285             mLatencyPrepareFromCache.putToBundle(results, testName + "_prepare_from_cache");
286         }
287         if (mCompilationCacheSizeBytes > 0) {
288             results.putInt(testName + "_compilation_cache_size", mCompilationCacheSizeBytes);
289         }
290         return results;
291     }
292 
293     @SuppressWarnings("AndroidJdkLibsChecker")
toCsvLine()294     public String toCsvLine() {
295         if (!TextUtils.isEmpty(mBenchmarkError)) {
296             return "";
297         }
298 
299         StringBuilder sb = new StringBuilder();
300         sb.append(mTestInfo).append(',').append(mBackendType);
301 
302         mLatencyInference.appendToCsvLine(sb);
303 
304         sb.append(',').append(String.join(",",
305             String.valueOf(mMaxSingleError),
306             String.valueOf(mTestSetSize),
307             String.valueOf(mEvaluatorKeys.length),
308             String.valueOf(mValidationErrors.length)));
309 
310         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
311             sb.append(',').append(mEvaluatorKeys[i]);
312         }
313 
314         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
315             sb.append(',').append(mEvaluatorResults[i]);
316         }
317 
318         for (String validationError : mValidationErrors) {
319             sb.append(',').append(validationError.replace(',', ' '));
320         }
321 
322         sb.append(',').append(mLatencyCompileWithoutCache != null);
323         if (mLatencyCompileWithoutCache != null) {
324             mLatencyCompileWithoutCache.appendToCsvLine(sb);
325         }
326         sb.append(',').append(mLatencySaveToCache != null);
327         if (mLatencySaveToCache != null) {
328             mLatencySaveToCache.appendToCsvLine(sb);
329         }
330         sb.append(',').append(mLatencyPrepareFromCache != null);
331         if (mLatencyPrepareFromCache != null) {
332             mLatencyPrepareFromCache.appendToCsvLine(sb);
333         }
334         sb.append(',').append(mCompilationCacheSizeBytes);
335 
336         sb.append('\n');
337         return sb.toString();
338     }
339 
fromInferenceResults( String testInfo, String backendType, List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, EvaluatorInterface evaluator)340     public static BenchmarkResult fromInferenceResults(
341             String testInfo,
342             String backendType,
343             List<InferenceInOutSequence> inferenceInOuts,
344             List<InferenceResult> inferenceResults,
345             EvaluatorInterface evaluator) {
346         float[] latencies = new float[inferenceResults.size()];
347         float sumOfMSEs = 0;
348         float maxSingleError = 0;
349         for (int i = 0; i < inferenceResults.size(); i++) {
350             InferenceResult iresult = inferenceResults.get(i);
351             latencies[i] = iresult.mComputeTimeSec;
352             if (iresult.mMeanSquaredErrors != null) {
353                 for (float mse : iresult.mMeanSquaredErrors) {
354                     sumOfMSEs += mse;
355                 }
356             }
357             if (iresult.mMaxSingleErrors != null) {
358                 for (float mse : iresult.mMaxSingleErrors) {
359                     if (mse > maxSingleError) {
360                         maxSingleError = mse;
361                     }
362                 }
363             }
364         }
365 
366         String[] evaluatorKeys = null;
367         float[] evaluatorResults = null;
368         String[] validationErrors = null;
369         if (evaluator != null) {
370             ArrayList<String> keys = new ArrayList<String>();
371             ArrayList<Float> results = new ArrayList<Float>();
372             ArrayList<String> validationErrorsList = new ArrayList<>();
373             evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results,
374                     validationErrorsList);
375             evaluatorKeys = new String[keys.size()];
376             evaluatorKeys = keys.toArray(evaluatorKeys);
377             evaluatorResults = new float[results.size()];
378             for (int i = 0; i < evaluatorResults.length; i++) {
379                 evaluatorResults[i] = results.get(i).floatValue();
380             }
381             validationErrors = new String[validationErrorsList.size()];
382             validationErrorsList.toArray(validationErrors);
383         }
384 
385         // Calc test set size
386         int testSetSize = 0;
387         for (InferenceInOutSequence iios : inferenceInOuts) {
388             testSetSize += iios.size();
389         }
390 
391         return new BenchmarkResult(new LatencyResult(latencies), sumOfMSEs, maxSingleError,
392                 testInfo, evaluatorKeys, evaluatorResults, backendType, testSetSize,
393                 validationErrors);
394     }
395 
setCompilationBenchmarkResult(CompilationBenchmarkResult result)396     public void setCompilationBenchmarkResult(CompilationBenchmarkResult result) {
397         mLatencyCompileWithoutCache = new LatencyResult(result.mCompileWithoutCacheTimeSec);
398         if (result.mSaveToCacheTimeSec != null) {
399             mLatencySaveToCache = new LatencyResult(result.mSaveToCacheTimeSec);
400         }
401         if (result.mPrepareFromCacheTimeSec != null) {
402             mLatencyPrepareFromCache = new LatencyResult(result.mPrepareFromCacheTimeSec);
403         }
404         mCompilationCacheSizeBytes = result.mCacheSizeBytes;
405     }
406 }
407