1 package com.android.nn.benchmark.evaluators; 2 3 import com.android.nn.benchmark.core.EvaluatorInterface; 4 import com.android.nn.benchmark.core.InferenceInOut; 5 import com.android.nn.benchmark.core.InferenceInOutSequence; 6 import com.android.nn.benchmark.core.InferenceResult; 7 import com.android.nn.benchmark.core.OutputMeanStdDev; 8 import com.android.nn.benchmark.util.IOUtils; 9 10 import java.util.List; 11 12 /** 13 * Base class for (input/output)sequence-by-sequence evaluation. 14 */ 15 public abstract class BaseSequenceEvaluator implements EvaluatorInterface { 16 private OutputMeanStdDev mOutputMeanStdDev = null; 17 protected int targetOutputIndex = 0; 18 setOutputMeanStdDev(OutputMeanStdDev outputMeanStdDev)19 public void setOutputMeanStdDev(OutputMeanStdDev outputMeanStdDev) { 20 mOutputMeanStdDev = outputMeanStdDev; 21 } 22 23 @Override EvaluateAccuracy( List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, List<String> outKeys, List<Float> outValues, List<String> outValidationErrors)24 public void EvaluateAccuracy( 25 List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, 26 List<String> outKeys, List<Float> outValues, 27 List<String> outValidationErrors) { 28 if (inferenceInOuts.isEmpty()) { 29 throw new IllegalArgumentException("Empty inputs/outputs"); 30 } 31 32 int dataSize = inferenceInOuts.get(0).mDatasize; 33 int outputSize = inferenceInOuts.get(0).get(0).mExpectedOutputs[targetOutputIndex].length 34 / dataSize; 35 int sequenceIndex = 0; 36 int inferenceIndex = 0; 37 while (inferenceIndex < inferenceResults.size()) { 38 int sequenceLength = inferenceInOuts.get(sequenceIndex % inferenceInOuts.size()).size(); 39 float[][] outputs = new float[sequenceLength][outputSize]; 40 float[][] expectedOutputs = new float[sequenceLength][outputSize]; 41 for (int i = 0; i < sequenceLength; ++i, ++inferenceIndex) { 42 InferenceResult result = inferenceResults.get(inferenceIndex); 43 if (mOutputMeanStdDev != null) { 44 System.arraycopy( 45 mOutputMeanStdDev.denormalize( 46 IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex], 47 dataSize)), 0, 48 outputs[i], 0, outputSize); 49 } else { 50 System.arraycopy( 51 IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex], 52 dataSize), 0, 53 outputs[i], 0, outputSize); 54 } 55 56 InferenceInOut inOut = inferenceInOuts.get(result.mInputOutputSequenceIndex) 57 .get(result.mInputOutputIndex); 58 if (mOutputMeanStdDev != null) { 59 System.arraycopy( 60 mOutputMeanStdDev.denormalize( 61 IOUtils.readFloats(inOut.mExpectedOutputs[targetOutputIndex], 62 dataSize)), 0, 63 expectedOutputs[i], 0, outputSize); 64 } else { 65 System.arraycopy( 66 IOUtils.readFloats(inOut.mExpectedOutputs[targetOutputIndex], dataSize), 67 0, 68 expectedOutputs[i], 0, outputSize); 69 } 70 } 71 72 EvaluateSequenceAccuracy(outputs, expectedOutputs, outValidationErrors); 73 ++sequenceIndex; 74 } 75 AddValidationResult(outKeys, outValues); 76 } 77 78 EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs, List<String> outValidationErrors)79 protected abstract void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs, 80 List<String> outValidationErrors); 81 AddValidationResult(List<String> keys, List<Float> values)82 protected abstract void AddValidationResult(List<String> keys, List<Float> values); 83 } 84