/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Information about available benchmarking models.
 *
 * <p>Models are added with {@link #registerModel(TestModelEntry)} and read back with
 * {@link #modelsList()}. The first call to {@link #modelsList()} freezes the registry; any
 * later registration attempt fails with {@link IllegalStateException}.
 */
public class TestModels {
    /** Entry for a single benchmarking model. */
    public static class TestModelEntry {
        /** Unique model name, used to find benchmark data. */
        public final String mModelName;

        /** Expected inference performance in seconds. */
        public final float mBaselineSec;

        /** Shape of input data. */
        public final int[] mInputShape;

        /** Input/output pairs backed by asset files. */
        public final InferenceInOutSequence.FromAssets[] mInOutAssets;

        /** Dataset inputs. */
        public final InferenceInOutSequence.FromDataset[] mInOutDatasets;

        /** Readable name for test output. */
        public final String mTestName;

        /** Name of model file, so that the same file can be reused. */
        public final String mModelFile;

        /** The evaluator to use for validating the results; {@code null} disables evaluation. */
        public final EvaluatorConfig mEvaluator;

        /** Min SDK version that the model can run on. */
        public final int mMinSdkVersion;

        /** Number of bytes per input data entry. */
        public final int mInDataSize;

        /**
         * Creates an entry; every argument is stored as-is in the field of the same name.
         *
         * <p>Array arguments are not copied, so later changes made by the caller to those
         * arrays are visible through this entry.
         */
        public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
                InferenceInOutSequence.FromAssets[] inOutAssets,
                InferenceInOutSequence.FromDataset[] inOutDatasets, String testName,
                String modelFile,
                EvaluatorConfig evaluator, int minSdkVersion, int inDataSize) {
            mModelName = modelName;
            mBaselineSec = baselineSec;
            mInputShape = inputShape;
            mInOutAssets = inOutAssets;
            mInOutDatasets = inOutDatasets;
            mTestName = testName;
            mModelFile = modelFile;
            mEvaluator = evaluator;
            mMinSdkVersion = minSdkVersion;
            mInDataSize = inDataSize;
        }

        /** Creates a new {@link NNTestBase} configured from this entry's fields. */
        public NNTestBase createNNTestBase() {
            return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
                    mEvaluator, mMinSdkVersion);
        }

        /**
         * Creates a new {@link NNTestBase} for the given backend with model mmap-ing disabled.
         *
         * @param tfLiteBackend backend passed to {@code NNTestBase#setTfLiteBackend}
         * @param enableIntermediateTensorsDump value passed to
         *         {@code NNTestBase#enableIntermediateTensorsDump}
         */
        public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend,
                boolean enableIntermediateTensorsDump) {
            return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump,
                    /*mmapModel=*/false);
        }

        /**
         * Creates a new {@link NNTestBase} selecting the backend from a boolean flag.
         *
         * <p>Used by CTS tests.
         *
         * @param useNNAPI {@code true} selects {@code TfLiteBackend.NNAPI}, {@code false}
         *         selects {@code TfLiteBackend.CPU}
         * @param enableIntermediateTensorsDump value passed to
         *         {@code NNTestBase#enableIntermediateTensorsDump}
         */
        public NNTestBase createNNTestBase(boolean useNNAPI,
                boolean enableIntermediateTensorsDump) {
            TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
            return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump,
                    /*mmapModel=*/false);
        }

        /**
         * Creates a new {@link NNTestBase} and applies the given backend, tensor-dump and mmap
         * settings to it before returning it.
         *
         * @param tfLiteBackend backend passed to {@code NNTestBase#setTfLiteBackend}
         * @param enableIntermediateTensorsDump value passed to
         *         {@code NNTestBase#enableIntermediateTensorsDump}
         * @param mmapModel value passed to {@code NNTestBase#setMmapModel}
         */
        public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend,
                boolean enableIntermediateTensorsDump,
                boolean mmapModel) {
            NNTestBase test = createNNTestBase();
            test.setTfLiteBackend(tfLiteBackend);
            test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
            test.setMmapModel(mmapModel);
            return test;
        }

        /** Returns the unique model name. */
        @Override
        public String toString() {
            return mModelName;
        }

        /** Returns the readable name used for test output. */
        public String getTestName() {
            return mTestName;
        }

        /**
         * Returns a copy of this entry whose evaluator is {@code null}; all other fields are
         * carried over unchanged.
         */
        public TestModelEntry withDisabledEvaluation() {
            return new TestModelEntry(mModelName, mBaselineSec, mInputShape, mInOutAssets,
                    mInOutDatasets, mTestName, mModelFile,
                    null, // Disable evaluation.
                    mMinSdkVersion, mInDataSize);
        }
    }

    /** Backing storage for registered models; only appended to by {@link #registerModel}. */
    private static final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();

    /**
     * Holds the read-only view handed out by {@link #modelsList()}; {@code null} until the list
     * has been frozen.
     */
    private static final AtomicReference<List<TestModelEntry>> frozenEntries =
            new AtomicReference<>(null);

    /**
     * Add new benchmark model.
     *
     * <p>NOTE(review): the frozen check and the add are not performed atomically; presumably
     * all registrations happen before the first {@link #modelsList()} call - confirm.
     *
     * @param model entry to append to the registry
     * @throws IllegalStateException if the list has already been frozen by
     *         {@link #modelsList()}
     */
    public static void registerModel(TestModelEntry model) {
        if (frozenEntries.get() != null) {
            throw new IllegalStateException("Can't register new models after its list is frozen");
        }
        sTestModelEntryList.add(model);
    }

    /** Returns {@code true} once {@link #modelsList()} has been called at least once. */
    public static boolean isListFrozen() {
        return frozenEntries.get() != null;
    }

    /**
     * Fetch list of test models.
     *
     * <p>If this method was called at least once, then it's impossible to register new models.
     * The returned list is an unmodifiable view, so the registry cannot be altered through it
     * either; mutating calls on it throw {@link UnsupportedOperationException}.
     *
     * @return the frozen, read-only list of registered models, in registration order
     */
    public static List<TestModelEntry> modelsList() {
        List<TestModelEntry> frozen = frozenEntries.get();
        if (frozen == null) {
            // Publish a read-only view rather than the live list so that callers cannot add
            // entries behind registerModel()'s frozen check. compareAndSet keeps the first
            // published view if several callers get here at once.
            frozenEntries.compareAndSet(null,
                    Collections.unmodifiableList(sTestModelEntryList));
            frozen = frozenEntries.get();
        }
        return frozen;
    }

    /**
     * Fetch model by its name.
     *
     * <p>Calling this freezes the model list, because it reads it through
     * {@link #modelsList()}.
     *
     * @param name model name compared against {@link TestModelEntry#mModelName}
     * @return the first registered entry whose model name equals {@code name}
     * @throws IllegalArgumentException if no registered entry has that name
     */
    public static TestModelEntry getModelByName(String name) {
        for (TestModelEntry testModelEntry : modelsList()) {
            if (testModelEntry.mModelName.equals(name)) {
                return testModelEntry;
            }
        }
        throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
    }

}