1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import java.util.ArrayList;
20 import java.util.List;
21 import java.util.concurrent.atomic.AtomicReference;
22 
23 /** Information about available benchmarking models */
24 public class TestModels {
25     /** Entry for a single benchmarking model */
26     public static class TestModelEntry {
27         /** Unique model name, used to find benchmark data */
28         public final String mModelName;
29 
30         /** Expected inference performance in seconds */
31         public final float mBaselineSec;
32 
33         /** Shape of input data */
34         public final int[] mInputShape;
35 
36         /** File pair asset input/output pairs */
37         public final InferenceInOutSequence.FromAssets[] mInOutAssets;
38 
39         /** Dataset inputs */
40         public final InferenceInOutSequence.FromDataset[] mInOutDatasets;
41 
42         /** Readable name for test output */
43         public final String mTestName;
44 
45         /** Name of model file, so that the same file can be reused */
46         public final String mModelFile;
47 
48         /** The evaluator to use for validating the results. */
49         public final EvaluatorConfig mEvaluator;
50 
51         /** Min SDK version that the model can run on. */
52         public final int mMinSdkVersion;
53 
54         /* Number of bytes per input data entry  */
55         public final int mInDataSize;
56 
TestModelEntry(String modelName, float baselineSec, int[] inputShape, InferenceInOutSequence.FromAssets[] inOutAssets, InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, String modelFile, EvaluatorConfig evaluator, int minSdkVersion, int inDataSize)57         public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
58                 InferenceInOutSequence.FromAssets[] inOutAssets,
59                 InferenceInOutSequence.FromDataset[] inOutDatasets, String testName,
60                 String modelFile,
61                 EvaluatorConfig evaluator, int minSdkVersion, int inDataSize) {
62             mModelName = modelName;
63             mBaselineSec = baselineSec;
64             mInputShape = inputShape;
65             mInOutAssets = inOutAssets;
66             mInOutDatasets = inOutDatasets;
67             mTestName = testName;
68             mModelFile = modelFile;
69             mEvaluator = evaluator;
70             mMinSdkVersion = minSdkVersion;
71             mInDataSize = inDataSize;
72         }
73 
createNNTestBase()74         public NNTestBase createNNTestBase() {
75             return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
76                     mEvaluator, mMinSdkVersion);
77         }
78 
createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump)79         public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump) {
80             return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
81         }
82 
83         // Used by CTS tests.
createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump)84         public NNTestBase createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump) {
85             TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
86             return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
87         }
88 
createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump, boolean mmapModel)89         public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump,
90                 boolean mmapModel) {
91             NNTestBase test = createNNTestBase();
92             test.setTfLiteBackend(tfLiteBackend);
93             test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
94             test.setMmapModel(mmapModel);
95             return test;
96         }
97 
toString()98         public String toString() {
99             return mModelName;
100         }
101 
getTestName()102         public String getTestName() {
103             return mTestName;
104         }
105 
106 
withDisabledEvaluation()107         public TestModelEntry withDisabledEvaluation() {
108             return new TestModelEntry(mModelName, mBaselineSec, mInputShape, mInOutAssets,
109                     mInOutDatasets, mTestName, mModelFile,
110                     null, // Disable evaluation.
111                     mMinSdkVersion, mInDataSize);
112         }
113     }
114 
115     static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
116     static private final AtomicReference<List<TestModelEntry>> frozenEntries =
117             new AtomicReference<>(null);
118 
119 
120     /** Add new benchmark model. */
registerModel(TestModelEntry model)121     static public void registerModel(TestModelEntry model) {
122         if (frozenEntries.get() != null) {
123             throw new IllegalStateException("Can't register new models after its list is frozen");
124         }
125         sTestModelEntryList.add(model);
126     }
127 
isListFrozen()128     public static boolean isListFrozen() {
129         return frozenEntries.get() != null;
130     }
131 
132     /**
133      * Fetch list of test models.
134      *
135      * If this method was called at least once, then it's impossible to register new models.
136      */
modelsList()137     static public List<TestModelEntry> modelsList() {
138         frozenEntries.compareAndSet(null, sTestModelEntryList);
139         return frozenEntries.get();
140     }
141 
142     /** Fetch model by its name. */
getModelByName(String name)143     static public TestModelEntry getModelByName(String name) {
144         for (TestModelEntry testModelEntry : modelsList()) {
145             if (testModelEntry.mModelName.equals(name)) {
146                 return testModelEntry;
147             }
148         }
149         throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
150     }
151 
152 }
153