1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.os.Bundle;
20 import android.os.Parcel;
21 import android.os.Parcelable;
22 
23 public class LatencyResult implements Parcelable {
24     private final static int TIME_FREQ_ARRAY_SIZE = 32;
25 
26     private float mTotalTimeSec;
27     private int mIterations;
28     private float mTimeStdDeviation;
29 
30     /** Time offset for inference frequency counts */
31     private float mTimeFreqStartSec;
32 
33     /** Index time offset for inference frequency counts */
34     private float mTimeFreqStepSec;
35 
36     /**
37      * Array of inference frequency counts.
38      * Each entry contains inference count for time range:
39      * [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (1+i*mTimeFreqStepSec)
40      */
41     private float[] mTimeFreqSec = {};
42 
LatencyResult(float[] results)43     public LatencyResult(float[] results) {
44         mIterations = results.length;
45         mTotalTimeSec = 0.0f;
46         float maxComputeTimeSec = 0.0f;
47         float minComputeTimeSec = Float.MAX_VALUE;
48         for (float result : results) {
49             mTotalTimeSec += result;
50             maxComputeTimeSec = Math.max(maxComputeTimeSec, result);
51             minComputeTimeSec = Math.min(minComputeTimeSec, result);
52         }
53 
54         // Calculate standard deviation.
55         float latencyMean = (mTotalTimeSec / mIterations);
56         float variance = 0.0f;
57         for (float result : results) {
58             float v = (result - latencyMean);
59             variance += v * v;
60         }
61         variance /= mIterations;
62         mTimeStdDeviation = (float) Math.sqrt(variance);
63 
64         // Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets.
65         mTimeFreqStartSec = minComputeTimeSec;
66         mTimeFreqStepSec = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1);
67         mTimeFreqSec = new float[TIME_FREQ_ARRAY_SIZE];
68         for (float result : results) {
69             int bucketIndex = (int) ((result - minComputeTimeSec) / mTimeFreqStepSec);
70             mTimeFreqSec[bucketIndex] += 1;
71         }
72     }
73 
LatencyResult(Parcel in)74     public LatencyResult(Parcel in) {
75         mTotalTimeSec = in.readFloat();
76         mIterations = in.readInt();
77         mTimeStdDeviation = in.readFloat();
78         mTimeFreqStartSec = in.readFloat();
79         mTimeFreqStepSec = in.readFloat();
80         int timeFreqSecLength = in.readInt();
81         mTimeFreqSec = new float[timeFreqSecLength];
82         in.readFloatArray(mTimeFreqSec);
83     }
84 
85     @Override
describeContents()86     public int describeContents() {
87         return 0;
88     }
89 
90     @Override
writeToParcel(Parcel dest, int flags)91     public void writeToParcel(Parcel dest, int flags) {
92         dest.writeFloat(mTotalTimeSec);
93         dest.writeInt(mIterations);
94         dest.writeFloat(mTimeStdDeviation);
95         dest.writeFloat(mTimeFreqStartSec);
96         dest.writeFloat(mTimeFreqStepSec);
97         dest.writeInt(mTimeFreqSec.length);
98         dest.writeFloatArray(mTimeFreqSec);
99     }
100 
101     public static final Parcelable.Creator<LatencyResult> CREATOR =
102         new Parcelable.Creator<LatencyResult>() {
103           public LatencyResult createFromParcel(Parcel in) {
104             return new LatencyResult(in);
105           }
106 
107           public LatencyResult[] newArray(int size) {
108             return new LatencyResult[size];
109           }
110         };
111 
putToBundle(Bundle results, String prefix)112     public void putToBundle(Bundle results, String prefix) {
113         // Reported in ms
114         results.putFloat(prefix + "_avg", getMeanTimeSec() * 1000.0f);
115         results.putFloat(prefix + "_std_dev", mTimeStdDeviation * 1000.0f);
116         results.putFloat(prefix + "_total_time", mTotalTimeSec * 1000.0f);
117         results.putInt(prefix + "_iterations", mIterations);
118     }
119 
120     @Override
toString()121     public String toString() {
122         return "LatencyResult{"
123                 + "getMeanTimeSec()=" + getMeanTimeSec()
124                 + ", mTotalTimeSec=" + mTotalTimeSec
125                 + ", mIterations=" + mIterations
126                 + ", mTimeStdDeviation=" + mTimeStdDeviation
127                 + ", mTimeFreqStartSec=" + mTimeFreqStartSec
128                 + ", mTimeFreqStepSec=" + mTimeFreqStepSec + "}";
129     }
130 
getIterations()131     public int getIterations() { return mIterations; }
132 
getMeanTimeSec()133     public float getMeanTimeSec() { return mTotalTimeSec / mIterations; }
134 
rebase(float v, float baselineSec)135     private float rebase(float v, float baselineSec) {
136         if (v > 0.001) {
137             v = baselineSec / v;
138         }
139         return v;
140     }
141 
getSummary(float baselineSec)142     public String getSummary(float baselineSec) {
143         java.text.DecimalFormat df = new java.text.DecimalFormat("######.##");
144         return df.format(rebase(getMeanTimeSec(), baselineSec)) + "X, n=" + mIterations
145                 + ", μ=" + df.format(getMeanTimeSec() * 1000.0)
146                 + "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms";
147     }
148 
appendToCsvLine(StringBuilder sb)149     public void appendToCsvLine(StringBuilder sb) {
150         sb.append(',').append(String.join(",",
151             String.valueOf(mIterations),
152             String.valueOf(mTotalTimeSec),
153             String.valueOf(mTimeFreqStartSec),
154             String.valueOf(mTimeFreqStepSec),
155             String.valueOf(mTimeFreqSec.length)));
156 
157         for (float value : mTimeFreqSec) {
158             sb.append(',').append(value);
159         }
160     }
161 }
162