1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.content.res.AssetManager; 20 21 import com.android.nn.benchmark.evaluators.MelCepLogF0; 22 import com.android.nn.benchmark.evaluators.TopK; 23 import com.android.nn.benchmark.util.IOUtils; 24 25 /** 26 * Config options for inference accuracy evaluators. 27 */ 28 public class EvaluatorConfig { 29 private String className; 30 31 // Optional. 32 private String outputMeanStdDev; 33 34 // Optional 35 private Double expectedTop1; 36 EvaluatorConfig(String className, String outputMeanStdDev, Double expectedTop1)37 public EvaluatorConfig(String className, String outputMeanStdDev, Double expectedTop1) { 38 this.className = className; 39 this.outputMeanStdDev = outputMeanStdDev; 40 this.expectedTop1 = expectedTop1; 41 } 42 createEvaluator(AssetManager assetManager)43 public EvaluatorInterface createEvaluator(AssetManager assetManager) { 44 try { 45 Class<?> clazz = Class.forName( 46 "com.android.nn.benchmark.evaluators." + className); 47 EvaluatorInterface evaluator = (EvaluatorInterface) clazz.getConstructor().newInstance(); 48 49 // TODO(pszczepaniak): Refactor this into something more managable. 50 if (clazz == MelCepLogF0.class && outputMeanStdDev != null) { 51 ((MelCepLogF0)evaluator).setOutputMeanStdDev(new OutputMeanStdDev( 52 IOUtils.readAsset( 53 assetManager, outputMeanStdDev, MeanStdDev.ELEMENT_SIZE_BYTES))); 54 } 55 if (clazz == TopK.class && expectedTop1 != null) { 56 ((TopK)evaluator).expectedTop1 = expectedTop1.floatValue(); 57 } 58 return evaluator; 59 } catch (Exception e) { 60 throw new IllegalArgumentException( 61 "Can not create evaluator named '" + className + "'", e); 62 } 63 } 64 } 65