1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.os.Bundle; 20 import android.os.Parcel; 21 import android.os.Parcelable; 22 23 public class LatencyResult implements Parcelable { 24 private final static int TIME_FREQ_ARRAY_SIZE = 32; 25 26 private float mTotalTimeSec; 27 private int mIterations; 28 private float mTimeStdDeviation; 29 30 /** Time offset for inference frequency counts */ 31 private float mTimeFreqStartSec; 32 33 /** Index time offset for inference frequency counts */ 34 private float mTimeFreqStepSec; 35 36 /** 37 * Array of inference frequency counts. 38 * Each entry contains inference count for time range: 39 * [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (1+i*mTimeFreqStepSec) 40 */ 41 private float[] mTimeFreqSec = {}; 42 LatencyResult(float[] results)43 public LatencyResult(float[] results) { 44 mIterations = results.length; 45 mTotalTimeSec = 0.0f; 46 float maxComputeTimeSec = 0.0f; 47 float minComputeTimeSec = Float.MAX_VALUE; 48 for (float result : results) { 49 mTotalTimeSec += result; 50 maxComputeTimeSec = Math.max(maxComputeTimeSec, result); 51 minComputeTimeSec = Math.min(minComputeTimeSec, result); 52 } 53 54 // Calculate standard deviation. 55 float latencyMean = (mTotalTimeSec / mIterations); 56 float variance = 0.0f; 57 for (float result : results) { 58 float v = (result - latencyMean); 59 variance += v * v; 60 } 61 variance /= mIterations; 62 mTimeStdDeviation = (float) Math.sqrt(variance); 63 64 // Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets. 65 mTimeFreqStartSec = minComputeTimeSec; 66 mTimeFreqStepSec = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1); 67 mTimeFreqSec = new float[TIME_FREQ_ARRAY_SIZE]; 68 for (float result : results) { 69 int bucketIndex = (int) ((result - minComputeTimeSec) / mTimeFreqStepSec); 70 mTimeFreqSec[bucketIndex] += 1; 71 } 72 } 73 LatencyResult(Parcel in)74 public LatencyResult(Parcel in) { 75 mTotalTimeSec = in.readFloat(); 76 mIterations = in.readInt(); 77 mTimeStdDeviation = in.readFloat(); 78 mTimeFreqStartSec = in.readFloat(); 79 mTimeFreqStepSec = in.readFloat(); 80 int timeFreqSecLength = in.readInt(); 81 mTimeFreqSec = new float[timeFreqSecLength]; 82 in.readFloatArray(mTimeFreqSec); 83 } 84 85 @Override describeContents()86 public int describeContents() { 87 return 0; 88 } 89 90 @Override writeToParcel(Parcel dest, int flags)91 public void writeToParcel(Parcel dest, int flags) { 92 dest.writeFloat(mTotalTimeSec); 93 dest.writeInt(mIterations); 94 dest.writeFloat(mTimeStdDeviation); 95 dest.writeFloat(mTimeFreqStartSec); 96 dest.writeFloat(mTimeFreqStepSec); 97 dest.writeInt(mTimeFreqSec.length); 98 dest.writeFloatArray(mTimeFreqSec); 99 } 100 101 public static final Parcelable.Creator<LatencyResult> CREATOR = 102 new Parcelable.Creator<LatencyResult>() { 103 public LatencyResult createFromParcel(Parcel in) { 104 return new LatencyResult(in); 105 } 106 107 public LatencyResult[] newArray(int size) { 108 return new LatencyResult[size]; 109 } 110 }; 111 putToBundle(Bundle results, String prefix)112 public void putToBundle(Bundle results, String prefix) { 113 // Reported in ms 114 results.putFloat(prefix + "_avg", getMeanTimeSec() * 1000.0f); 115 results.putFloat(prefix + "_std_dev", mTimeStdDeviation * 1000.0f); 116 results.putFloat(prefix + "_total_time", mTotalTimeSec * 1000.0f); 117 results.putInt(prefix + "_iterations", mIterations); 118 } 119 120 @Override toString()121 public String toString() { 122 return "LatencyResult{" 123 + "getMeanTimeSec()=" + getMeanTimeSec() 124 + ", mTotalTimeSec=" + mTotalTimeSec 125 + ", mIterations=" + mIterations 126 + ", mTimeStdDeviation=" + mTimeStdDeviation 127 + ", mTimeFreqStartSec=" + mTimeFreqStartSec 128 + ", mTimeFreqStepSec=" + mTimeFreqStepSec + "}"; 129 } 130 getIterations()131 public int getIterations() { return mIterations; } 132 getMeanTimeSec()133 public float getMeanTimeSec() { return mTotalTimeSec / mIterations; } 134 rebase(float v, float baselineSec)135 private float rebase(float v, float baselineSec) { 136 if (v > 0.001) { 137 v = baselineSec / v; 138 } 139 return v; 140 } 141 getSummary(float baselineSec)142 public String getSummary(float baselineSec) { 143 java.text.DecimalFormat df = new java.text.DecimalFormat("######.##"); 144 return df.format(rebase(getMeanTimeSec(), baselineSec)) + "X, n=" + mIterations 145 + ", μ=" + df.format(getMeanTimeSec() * 1000.0) 146 + "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms"; 147 } 148 appendToCsvLine(StringBuilder sb)149 public void appendToCsvLine(StringBuilder sb) { 150 sb.append(',').append(String.join(",", 151 String.valueOf(mIterations), 152 String.valueOf(mTotalTimeSec), 153 String.valueOf(mTimeFreqStartSec), 154 String.valueOf(mTimeFreqStepSec), 155 String.valueOf(mTimeFreqSec.length))); 156 157 for (float value : mTimeFreqSec) { 158 sb.append(',').append(value); 159 } 160 } 161 } 162