1 package com.android.nn.benchmark.evaluators;
2 
3 import com.android.nn.benchmark.core.EvaluatorInterface;
4 import com.android.nn.benchmark.core.InferenceInOut;
5 import com.android.nn.benchmark.core.InferenceInOutSequence;
6 import com.android.nn.benchmark.core.InferenceResult;
7 import com.android.nn.benchmark.core.OutputMeanStdDev;
8 import com.android.nn.benchmark.util.IOUtils;
9 
10 import java.util.List;
11 
12 /**
 * Base class for evaluators that score inference results one input/output sequence at a time.
14  */
15 public abstract class BaseSequenceEvaluator implements EvaluatorInterface {
16     private OutputMeanStdDev mOutputMeanStdDev = null;
17     protected int targetOutputIndex = 0;
18 
setOutputMeanStdDev(OutputMeanStdDev outputMeanStdDev)19     public void setOutputMeanStdDev(OutputMeanStdDev outputMeanStdDev) {
20         mOutputMeanStdDev = outputMeanStdDev;
21     }
22 
23     @Override
EvaluateAccuracy( List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, List<String> outKeys, List<Float> outValues, List<String> outValidationErrors)24     public void EvaluateAccuracy(
25             List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults,
26             List<String> outKeys, List<Float> outValues,
27             List<String> outValidationErrors) {
28         if (inferenceInOuts.isEmpty()) {
29             throw new IllegalArgumentException("Empty inputs/outputs");
30         }
31 
32         int dataSize = inferenceInOuts.get(0).mDatasize;
33         int outputSize = inferenceInOuts.get(0).get(0).mExpectedOutputs[targetOutputIndex].length
34                 / dataSize;
35         int sequenceIndex = 0;
36         int inferenceIndex = 0;
37         while (inferenceIndex < inferenceResults.size()) {
38             int sequenceLength = inferenceInOuts.get(sequenceIndex % inferenceInOuts.size()).size();
39             float[][] outputs = new float[sequenceLength][outputSize];
40             float[][] expectedOutputs = new float[sequenceLength][outputSize];
41             for (int i = 0; i < sequenceLength; ++i, ++inferenceIndex) {
42                 InferenceResult result = inferenceResults.get(inferenceIndex);
43                 if (mOutputMeanStdDev != null) {
44                     System.arraycopy(
45                             mOutputMeanStdDev.denormalize(
46                                     IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex],
47                                             dataSize)), 0,
48                             outputs[i], 0, outputSize);
49                 } else {
50                     System.arraycopy(
51                             IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex],
52                                     dataSize), 0,
53                             outputs[i], 0, outputSize);
54                 }
55 
56                 InferenceInOut inOut = inferenceInOuts.get(result.mInputOutputSequenceIndex)
57                         .get(result.mInputOutputIndex);
58                 if (mOutputMeanStdDev != null) {
59                     System.arraycopy(
60                             mOutputMeanStdDev.denormalize(
61                                     IOUtils.readFloats(inOut.mExpectedOutputs[targetOutputIndex],
62                                             dataSize)), 0,
63                             expectedOutputs[i], 0, outputSize);
64                 } else {
65                     System.arraycopy(
66                             IOUtils.readFloats(inOut.mExpectedOutputs[targetOutputIndex], dataSize),
67                             0,
68                             expectedOutputs[i], 0, outputSize);
69                 }
70             }
71 
72             EvaluateSequenceAccuracy(outputs, expectedOutputs, outValidationErrors);
73             ++sequenceIndex;
74         }
75         AddValidationResult(outKeys, outValues);
76     }
77 
78 
EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs, List<String> outValidationErrors)79     protected abstract void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs,
80             List<String> outValidationErrors);
81 
AddValidationResult(List<String> keys, List<Float> values)82     protected abstract void AddValidationResult(List<String> keys, List<Float> values);
83 }
84