1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.evaluators;
18 
19 import android.util.Pair;
20 
21 import com.android.nn.benchmark.core.EvaluatorInterface;
22 import com.android.nn.benchmark.core.InferenceInOut;
23 import com.android.nn.benchmark.core.InferenceInOutSequence;
24 import com.android.nn.benchmark.core.InferenceResult;
25 import com.android.nn.benchmark.util.IOUtils;
26 
27 import java.util.Comparator;
28 import java.util.List;
29 import java.util.PriorityQueue;
30 
31 /**
32  * Accuracy evaluator for classifiers - top-k accuracy (with k=5).
33  */
34 
35 public class TopK implements EvaluatorInterface {
36 
37     public static final int K_TOP = 5;
38     public static final float VALIDATION_TOP1_THRESHOLD = 0.05f;
39     public float expectedTop1 = 0.0f;
40     public int targetOutputIndex = 0;
41 
EvaluateAccuracy( List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, List<String> outKeys, List<Float> outValues, List<String> outValidationErrors)42     public void EvaluateAccuracy(
43             List<InferenceInOutSequence> inferenceInOuts,
44             List<InferenceResult> inferenceResults,
45             List<String> outKeys,
46             List<Float> outValues,
47             List<String> outValidationErrors) {
48 
49         int total = 0;
50         int[] topk = new int[K_TOP];
51         for (int i = 0; i < inferenceResults.size(); i++) {
52             InferenceResult result = inferenceResults.get(i);
53             if (result.mInferenceOutput == null) {
54                 throw new IllegalArgumentException("Needs mInferenceOutput for TopK");
55             }
56             InferenceInOutSequence sequence = inferenceInOuts.get(result.mInputOutputSequenceIndex);
57             if (sequence.size() != 1) {
58                 throw new IllegalArgumentException("Only one item in InferenceInOutSequenece " +
59                         "supported by TopK evaluator");
60             }
61             if (result.mInputOutputIndex != 0) {
62                 throw new IllegalArgumentException("Unexpected non-zero InputOutputIndex");
63             }
64             InferenceInOut io = sequence.get(0);
65             final int expectedClass = io.mExpectedClass;
66             if (expectedClass < 0) {
67                 throw new IllegalArgumentException("expected class not set");
68             }
69             PriorityQueue<Pair<Integer, Float>> sorted = new PriorityQueue<Pair<Integer, Float>>(
70                     new Comparator<Pair<Integer, Float>>() {
71                         @Override
72                         public int compare(Pair<Integer, Float> o1, Pair<Integer, Float> o2) {
73                             // Note reverse order to get highest probability first
74                             return o2.second.compareTo(o1.second);
75                         }
76                     });
77             float[] probabilities = IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex],
78                     sequence.mDatasize);
79             for (int index = 0; index < probabilities.length; index++) {
80                 sorted.add(new Pair<>(index, probabilities[index]));
81             }
82             total++;
83             boolean seen = false;
84             for (int k = 0; k < K_TOP; k++) {
85                 Pair<Integer, Float> top = sorted.remove();
86                 if (top.first.intValue() == expectedClass) {
87                     seen = true;
88                 }
89                 if (seen) {
90                     topk[k]++;
91                 }
92             }
93         }
94         for (int i = 0; i < K_TOP; i++) {
95             outKeys.add("top_" + (i + 1));
96             outValues.add(new Float((float) topk[i] / (float) total));
97         }
98 
99         if (expectedTop1 > 0.0) {
100             float top1 = ((float) topk[0] / (float) total);
101             float lowestTop1 = expectedTop1 - VALIDATION_TOP1_THRESHOLD;
102             if (top1 < lowestTop1) {
103                 outValidationErrors.add(
104                         "Top 1 value is far below the validation threshold " +
105                                 String.format("%.2f%%", expectedTop1 * 100.0));
106             }
107         }
108     }
109 }
110