1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.content.res.AssetManager;
20 
21 import com.android.nn.benchmark.evaluators.MelCepLogF0;
22 import com.android.nn.benchmark.evaluators.TopK;
23 import com.android.nn.benchmark.util.IOUtils;
24 
25 /**
26  * Config options for inference accuracy evaluators.
27  */
28 public class EvaluatorConfig {
29     private String className;
30 
31     // Optional.
32     private String outputMeanStdDev;
33 
34     // Optional
35     private Double expectedTop1;
36 
EvaluatorConfig(String className, String outputMeanStdDev, Double expectedTop1)37     public EvaluatorConfig(String className, String outputMeanStdDev, Double expectedTop1) {
38         this.className = className;
39         this.outputMeanStdDev = outputMeanStdDev;
40         this.expectedTop1 = expectedTop1;
41     }
42 
createEvaluator(AssetManager assetManager)43     public EvaluatorInterface createEvaluator(AssetManager assetManager) {
44         try {
45             Class<?> clazz = Class.forName(
46                     "com.android.nn.benchmark.evaluators." + className);
47             EvaluatorInterface evaluator = (EvaluatorInterface) clazz.getConstructor().newInstance();
48 
49             // TODO(pszczepaniak): Refactor this into something more managable.
50             if (clazz == MelCepLogF0.class && outputMeanStdDev != null) {
51                 ((MelCepLogF0)evaluator).setOutputMeanStdDev(new OutputMeanStdDev(
52                         IOUtils.readAsset(
53                         assetManager, outputMeanStdDev, MeanStdDev.ELEMENT_SIZE_BYTES)));
54             }
55             if (clazz == TopK.class && expectedTop1 != null) {
56                 ((TopK)evaluator).expectedTop1 = expectedTop1.floatValue();
57             }
58             return evaluator;
59         } catch (Exception e) {
60             throw new IllegalArgumentException(
61                     "Can not create evaluator named '" + className + "'", e);
62         }
63     }
64 }
65