1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
22 
23 /** Information about available benchmarking models */
24 public class TestModels {
25     /** Entry for a single benchmarking model */
26     public static class TestModelEntry {
27         /** Unique model name, used to find benchmark data */
28         public final String mModelName;
29 
30         /** Expected inference performance in seconds */
31         public final float mBaselineSec;
32 
33         /** Shape of input data */
34         public final int[] mInputShape;
35 
36         /** File pair asset input/output pairs */
37         public final InferenceInOutSequence.FromAssets[] mInOutAssets;
38 
39         /** Dataset inputs */
40         public final InferenceInOutSequence.FromDataset[] mInOutDatasets;
41 
42         /** Readable name for test output */
43         public final String mTestName;
44 
45         /** Name of model file, so that the same file can be reused */
46         public final String mModelFile;
47 
48         /** The evaluator to use for validating the results. */
49         public final EvaluatorConfig mEvaluator;
50 
51         /** Min SDK version that the model can run on. */
52         public final int mMinSdkVersion;
53 
TestModelEntry(String modelName, float baselineSec, int[] inputShape, InferenceInOutSequence.FromAssets[] inOutAssets, InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, String modelFile, EvaluatorConfig evaluator, int minSdkVersion)54         public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
55                               InferenceInOutSequence.FromAssets[] inOutAssets,
56                               InferenceInOutSequence.FromDataset[] inOutDatasets,
57                               String testName, String modelFile, EvaluatorConfig evaluator,
58                               int minSdkVersion) {
59             mModelName = modelName;
60             mBaselineSec = baselineSec;
61             mInputShape = inputShape;
62             mInOutAssets = inOutAssets;
63             mInOutDatasets  = inOutDatasets;
64             mTestName = testName;
65             mModelFile = modelFile;
66             mEvaluator = evaluator;
67             mMinSdkVersion = minSdkVersion;
68         }
69 
createNNTestBase()70         public NNTestBase createNNTestBase() {
71             return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
72                 mEvaluator, mMinSdkVersion);
73         }
74 
createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump)75         public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) {
76             NNTestBase test = createNNTestBase();
77             test.useNNApi(useNNApi);
78             test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
79             return test;
80         }
81 
toString()82         public String toString() {
83             return mModelName;
84         }
85 
getTestName()86         public String getTestName() {
87             return mTestName;
88         }
89     }
90 
91     static private List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
92     static private volatile boolean sTestModelEntryListFrozen = false;
93 
94     /** Add new benchmark model. */
registerModel(TestModelEntry model)95     static public void registerModel(TestModelEntry model) {
96         if (sTestModelEntryListFrozen) {
97             throw new IllegalStateException("Can't register new models after its list is frozen");
98         }
99         sTestModelEntryList.add(model);
100     }
101 
102     /** Fetch list of test models.
103      *
104      * If this method was called at least once, then it's impossible to register new models.
105      */
modelsList()106     static public List<TestModelEntry> modelsList() {
107         if (!sTestModelEntryListFrozen) {
108             // If this method was called once, make models list unmodifiable
109             synchronized (TestModels.class) {
110                 if (!sTestModelEntryListFrozen) {
111                     sTestModelEntryList = Collections.unmodifiableList(sTestModelEntryList);
112                     sTestModelEntryListFrozen = true;
113                 }
114             }
115         }
116         return sTestModelEntryList;
117     }
118 
119     /** Fetch model by its name. */
getModelByName(String name)120     static public TestModelEntry getModelByName(String name) {
121         for (TestModelEntry testModelEntry : modelsList()) {
122             if (testModelEntry.mModelName.equals(name)) {
123                 return testModelEntry;
124             }
125         }
126         throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
127     }
128 }
129