/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import java.util.List;
import java.util.ArrayList;
import java.util.Collections;

/**
 * Information about available benchmarking models.
 *
 * <p>Models are added with {@link #registerModel(TestModelEntry)}. The first call to
 * {@link #modelsList()} freezes the registry; later registrations are rejected with an
 * {@link IllegalStateException}. Registration and freezing are guarded by the
 * {@code TestModels.class} monitor.
 */
public class TestModels {
    /** Entry for a single benchmarking model. */
    public static class TestModelEntry {
        /** Unique model name, used to find benchmark data. */
        public final String mModelName;

        /** Expected inference performance in seconds. */
        public final float mBaselineSec;

        /** Shape of input data. The array is stored as passed in (not copied). */
        public final int[] mInputShape;

        /** Input/output pairs backed by asset files. */
        public final InferenceInOutSequence.FromAssets[] mInOutAssets;

        /** Dataset inputs. */
        public final InferenceInOutSequence.FromDataset[] mInOutDatasets;

        /** Readable name for test output. */
        public final String mTestName;

        /** Name of model file, so that the same file can be reused. */
        public final String mModelFile;

        /** The evaluator to use for validating the results. */
        public final EvaluatorConfig mEvaluator;

        /** Min SDK version that the model can run on. */
        public final int mMinSdkVersion;

        /**
         * Creates an entry; every argument is stored as-is in the field of the matching name.
         *
         * @param modelName unique model name, used to find benchmark data
         * @param baselineSec expected inference performance in seconds
         * @param inputShape shape of input data
         * @param inOutAssets asset-backed input/output pairs
         * @param inOutDatasets dataset inputs
         * @param testName readable name for test output
         * @param modelFile name of the model file
         * @param evaluator evaluator used for validating the results
         * @param minSdkVersion min SDK version that the model can run on
         */
        public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
                InferenceInOutSequence.FromAssets[] inOutAssets,
                InferenceInOutSequence.FromDataset[] inOutDatasets,
                String testName, String modelFile, EvaluatorConfig evaluator,
                int minSdkVersion) {
            mModelName = modelName;
            mBaselineSec = baselineSec;
            mInputShape = inputShape;
            mInOutAssets = inOutAssets;
            mInOutDatasets = inOutDatasets;
            mTestName = testName;
            mModelFile = modelFile;
            mEvaluator = evaluator;
            mMinSdkVersion = minSdkVersion;
        }

        /**
         * Builds a new {@link NNTestBase} from this entry's model name, model file, input
         * shape, in/out assets, in/out datasets, evaluator and min SDK version.
         *
         * @return a freshly constructed test base
         */
        public NNTestBase createNNTestBase() {
            return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
                    mEvaluator, mMinSdkVersion);
        }

        /**
         * Builds a new {@link NNTestBase} as {@link #createNNTestBase()} does and then forwards
         * the two flags to it.
         *
         * @param useNNApi value passed to {@code NNTestBase.useNNApi}
         * @param enableIntermediateTensorsDump value passed to
         *        {@code NNTestBase.enableIntermediateTensorsDump}
         * @return the configured test base
         */
        public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) {
            NNTestBase test = createNNTestBase();
            test.useNNApi(useNNApi);
            test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
            return test;
        }

        /** Returns the unique model name ({@link #mModelName}). */
        @Override
        public String toString() {
            return mModelName;
        }

        /** Returns the readable test name ({@link #mTestName}). */
        public String getTestName() {
            return mTestName;
        }
    }

    // Registry state. All mutations happen while holding the TestModels.class monitor.
    // The frozen flag is volatile so modelsList() can skip the lock once the list is frozen:
    // the list reference is written before the flag, so a reader that observes
    // frozen == true also observes the unmodifiable list.
    private static List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
    private static volatile boolean sTestModelEntryListFrozen = false;

    /**
     * Add new benchmark model.
     *
     * <p>Synchronized on the same monitor that {@link #modelsList()} uses for freezing, so a
     * registration can never slip into the backing list after the frozen view was published,
     * and concurrent registrations do not race on the underlying {@link ArrayList}.
     *
     * @param model entry to append to the registry
     * @throws IllegalStateException if {@link #modelsList()} has already been called
     */
    public static synchronized void registerModel(TestModelEntry model) {
        if (sTestModelEntryListFrozen) {
            throw new IllegalStateException("Can't register new models after its list is frozen");
        }
        sTestModelEntryList.add(model);
    }

    /**
     * Fetch list of test models.
     *
     * <p>If this method was called at least once, then it's impossible to register new models.
     *
     * @return an unmodifiable list of the registered entries, in registration order
     */
    public static List<TestModelEntry> modelsList() {
        if (!sTestModelEntryListFrozen) {
            // First call: make the models list unmodifiable. Re-check under the lock so that
            // the list is wrapped exactly once even if several threads arrive together.
            synchronized (TestModels.class) {
                if (!sTestModelEntryListFrozen) {
                    sTestModelEntryList = Collections.unmodifiableList(sTestModelEntryList);
                    sTestModelEntryListFrozen = true;
                }
            }
        }
        return sTestModelEntryList;
    }

    /**
     * Fetch model by its name.
     *
     * <p>Calls {@link #modelsList()}, so the registry is frozen as a side effect.
     *
     * @param name model name to compare against each entry's {@code mModelName}
     * @return the first registered entry whose {@code mModelName} equals {@code name}
     * @throws IllegalArgumentException if no registered entry has that name
     */
    public static TestModelEntry getModelByName(String name) {
        for (TestModelEntry testModelEntry : modelsList()) {
            if (testModelEntry.mModelName.equals(name)) {
                return testModelEntry;
            }
        }
        throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
    }
}