/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Information about available benchmarking models.
 *
 * <p>Models are added with {@link #registerModel}. The first call to {@link #modelsList()} (or
 * {@link #getModelByName}) freezes the registry; registering after that point throws.
 */
public class TestModels {
    /** Entry for a single benchmarking model. */
    public static class TestModelEntry {
        /** Unique model name, used to find benchmark data. */
        public final String mModelName;

        /** Expected inference performance in seconds. */
        public final float mBaselineSec;

        /** Shape of input data. */
        public final int[] mInputShape;

        /** Input/output file pairs loaded from assets. */
        public final InferenceInOutSequence.FromAssets[] mInOutAssets;

        /** Dataset inputs. */
        public final InferenceInOutSequence.FromDataset[] mInOutDatasets;

        /** Readable name for test output. */
        public final String mTestName;

        /** Name of the model file, so that the same file can be reused. */
        public final String mModelFile;

        /** The evaluator to use for validating the results. */
        public final EvaluatorConfig mEvaluator;

        /** Min SDK version that the model can run on. */
        public final int mMinSdkVersion;

        /**
         * Creates an entry. All arguments are stored as-is (arrays are not copied).
         *
         * @param modelName unique model name, used to find benchmark data
         * @param baselineSec expected inference performance in seconds
         * @param inputShape shape of input data
         * @param inOutAssets input/output pairs loaded from assets
         * @param inOutDatasets dataset inputs
         * @param testName readable name for test output
         * @param modelFile name of the model file
         * @param evaluator evaluator used to validate results
         * @param minSdkVersion minimum SDK version the model can run on
         */
        public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
                InferenceInOutSequence.FromAssets[] inOutAssets,
                InferenceInOutSequence.FromDataset[] inOutDatasets,
                String testName, String modelFile, EvaluatorConfig evaluator,
                int minSdkVersion) {
            mModelName = modelName;
            mBaselineSec = baselineSec;
            mInputShape = inputShape;
            mInOutAssets = inOutAssets;
            mInOutDatasets = inOutDatasets;
            mTestName = testName;
            mModelFile = modelFile;
            mEvaluator = evaluator;
            mMinSdkVersion = minSdkVersion;
        }

        /**
         * Creates a new {@link NNTestBase} from this entry's model name, model file, input shape,
         * asset and dataset inputs, evaluator, and min SDK version.
         */
        public NNTestBase createNNTestBase() {
            return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
                    mEvaluator, mMinSdkVersion);
        }

        /**
         * Creates a new {@link NNTestBase} (see {@link #createNNTestBase()}) and applies the given
         * flags via {@code useNNApi} and {@code enableIntermediateTensorsDump}.
         *
         * @param useNNApi value passed to {@link NNTestBase#useNNApi}
         * @param enableIntermediateTensorsDump value passed to
         *     {@link NNTestBase#enableIntermediateTensorsDump}
         */
        public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) {
            NNTestBase test = createNNTestBase();
            test.useNNApi(useNNApi);
            test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
            return test;
        }

        /** Returns the model name. */
        @Override
        public String toString() {
            return mModelName;
        }

        /** Returns the readable name used for test output. */
        public String getTestName() {
            return mTestName;
        }
    }

    /** Backing store for registered models; only mutated while holding the class lock. */
    private static final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();

    /**
     * Unmodifiable view of {@link #sTestModelEntryList}, set once when the registry is frozen.
     * {@code null} means registration is still open.
     */
    private static final AtomicReference<List<TestModelEntry>> frozenEntries =
            new AtomicReference<>(null);

    /**
     * Adds a new benchmark model.
     *
     * <p>Synchronized with the freezing step in {@link #modelsList()} so a registration cannot
     * slip in concurrently with (or after) the freeze.
     *
     * @param model the model to register; must not be null
     * @throws NullPointerException if {@code model} is null
     * @throws IllegalStateException if the model list has already been frozen
     */
    public static synchronized void registerModel(TestModelEntry model) {
        Objects.requireNonNull(model, "model");
        if (frozenEntries.get() != null) {
            throw new IllegalStateException("Can't register new models after its list is frozen");
        }
        sTestModelEntryList.add(model);
    }

    /**
     * Fetches the list of test models.
     *
     * <p>The first call freezes the registry: afterwards {@link #registerModel} throws. Every call
     * returns the same unmodifiable list, so callers cannot add or remove entries behind the
     * registry's back.
     *
     * @return an unmodifiable list of all registered models
     */
    public static List<TestModelEntry> modelsList() {
        // Fast path: already frozen, no locking needed (AtomicReference gives safe publication).
        List<TestModelEntry> frozen = frozenEntries.get();
        if (frozen != null) {
            return frozen;
        }
        synchronized (TestModels.class) {
            frozen = frozenEntries.get();
            if (frozen == null) {
                // A view is enough: registerModel refuses all writes once this is set.
                frozen = Collections.unmodifiableList(sTestModelEntryList);
                frozenEntries.set(frozen);
            }
            return frozen;
        }
    }

    /**
     * Fetches a model by its name. Freezes the registry as a side effect (via
     * {@link #modelsList()}).
     *
     * @param name the model name to look up
     * @return the first registered entry whose {@code mModelName} equals {@code name}
     * @throws IllegalArgumentException if no registered model has that name
     */
    public static TestModelEntry getModelByName(String name) {
        for (TestModelEntry testModelEntry : modelsList()) {
            if (testModelEntry.mModelName.equals(name)) {
                return testModelEntry;
            }
        }
        throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
    }
}