1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.evaluators; 18 19 import android.util.Pair; 20 21 import com.android.nn.benchmark.core.EvaluatorInterface; 22 import com.android.nn.benchmark.core.InferenceInOut; 23 import com.android.nn.benchmark.core.InferenceInOutSequence; 24 import com.android.nn.benchmark.core.InferenceResult; 25 import com.android.nn.benchmark.util.IOUtils; 26 27 import java.util.Comparator; 28 import java.util.List; 29 import java.util.PriorityQueue; 30 31 /** 32 * Accuracy evaluator for classifiers - top-k accuracy (with k=5). 33 */ 34 35 public class TopK implements EvaluatorInterface { 36 37 public static final int K_TOP = 5; 38 public static final float VALIDATION_TOP1_THRESHOLD = 0.05f; 39 public float expectedTop1 = 0.0f; 40 public int targetOutputIndex = 0; 41 EvaluateAccuracy( List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, List<String> outKeys, List<Float> outValues, List<String> outValidationErrors)42 public void EvaluateAccuracy( 43 List<InferenceInOutSequence> inferenceInOuts, 44 List<InferenceResult> inferenceResults, 45 List<String> outKeys, 46 List<Float> outValues, 47 List<String> outValidationErrors) { 48 49 int total = 0; 50 int[] topk = new int[K_TOP]; 51 for (int i = 0; i < inferenceResults.size(); i++) { 52 InferenceResult result = inferenceResults.get(i); 53 if (result.mInferenceOutput == null) { 54 throw new IllegalArgumentException("Needs mInferenceOutput for TopK"); 55 } 56 InferenceInOutSequence sequence = inferenceInOuts.get(result.mInputOutputSequenceIndex); 57 if (sequence.size() != 1) { 58 throw new IllegalArgumentException("Only one item in InferenceInOutSequenece " + 59 "supported by TopK evaluator"); 60 } 61 if (result.mInputOutputIndex != 0) { 62 throw new IllegalArgumentException("Unexpected non-zero InputOutputIndex"); 63 } 64 InferenceInOut io = sequence.get(0); 65 final int expectedClass = io.mExpectedClass; 66 if (expectedClass < 0) { 67 throw new IllegalArgumentException("expected class not set"); 68 } 69 PriorityQueue<Pair<Integer, Float>> sorted = new PriorityQueue<Pair<Integer, Float>>( 70 new Comparator<Pair<Integer, Float>>() { 71 @Override 72 public int compare(Pair<Integer, Float> o1, Pair<Integer, Float> o2) { 73 // Note reverse order to get highest probability first 74 return o2.second.compareTo(o1.second); 75 } 76 }); 77 float[] probabilities = IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex], 78 sequence.mDatasize); 79 for (int index = 0; index < probabilities.length; index++) { 80 sorted.add(new Pair<>(index, probabilities[index])); 81 } 82 total++; 83 boolean seen = false; 84 for (int k = 0; k < K_TOP; k++) { 85 Pair<Integer, Float> top = sorted.remove(); 86 if (top.first.intValue() == expectedClass) { 87 seen = true; 88 } 89 if (seen) { 90 topk[k]++; 91 } 92 } 93 } 94 for (int i = 0; i < K_TOP; i++) { 95 outKeys.add("top_" + (i + 1)); 96 outValues.add(new Float((float) topk[i] / (float) total)); 97 } 98 99 if (expectedTop1 > 0.0) { 100 float top1 = ((float) topk[0] / (float) total); 101 float lowestTop1 = expectedTop1 - VALIDATION_TOP1_THRESHOLD; 102 if (top1 < lowestTop1) { 103 outValidationErrors.add( 104 "Top 1 value is far below the validation threshold " + 105 String.format("%.2f%%", expectedTop1 * 100.0)); 106 } 107 } 108 } 109 } 110