1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.cts; 18 19 import static junit.framework.TestCase.assertFalse; 20 21 import android.app.Activity; 22 import android.util.Log; 23 import android.util.Pair; 24 25 import androidx.test.filters.LargeTest; 26 import androidx.test.filters.RequiresDevice; 27 import androidx.test.rule.ActivityTestRule; 28 29 import com.android.nn.benchmark.core.BenchmarkException; 30 import com.android.nn.benchmark.core.BenchmarkResult; 31 import com.android.nn.benchmark.core.InferenceInOutSequence; 32 import com.android.nn.benchmark.core.InferenceResult; 33 import com.android.nn.benchmark.core.NNTestBase; 34 import com.android.nn.benchmark.core.TestModels; 35 36 import org.junit.Before; 37 import org.junit.Rule; 38 import org.junit.Test; 39 import org.junit.runner.RunWith; 40 import org.junit.runners.Parameterized; 41 import org.junit.runners.Parameterized.Parameters; 42 43 import java.io.IOException; 44 import java.util.ArrayList; 45 import java.util.Collections; 46 import java.util.List; 47 48 /** 49 * Tests the accuracy of the model outputs. 50 */ 51 @RunWith(Parameterized.class) 52 public class NNAccuracyTest { 53 protected static final String TAG = NNAccuracyTest.class.getSimpleName(); 54 55 @Rule 56 public ActivityTestRule<NNAccuracyActivity> mActivityRule = 57 new ActivityTestRule<>(NNAccuracyActivity.class); 58 59 @Parameterized.Parameter(0) 60 public TestModels.TestModelEntry mModel; 61 62 private Activity mActivity; 63 64 // TODO(vddang): Add mobilenet_v1_0.25_128_quant_topk_aosp 65 private static final String[] MODEL_NAMES = new String[]{ 66 "tts_float", 67 "asr_float", 68 "mobilenet_v1_1.0_224_quant_topk_aosp", 69 "mobilenet_v1_1.0_224_topk_aosp", 70 "mobilenet_v1_0.75_192_quant_topk_aosp", 71 "mobilenet_v1_0.75_192_topk_aosp", 72 "mobilenet_v1_0.5_160_quant_topk_aosp", 73 "mobilenet_v1_0.5_160_topk_aosp", 74 "mobilenet_v1_0.25_128_topk_aosp", 75 "mobilenet_v2_0.35_128_topk_aosp", 76 "mobilenet_v2_0.5_160_topk_aosp", 77 "mobilenet_v2_0.75_192_topk_aosp", 78 "mobilenet_v2_1.0_224_quant_topk_aosp", 79 "mobilenet_v2_1.0_224_topk_aosp", 80 }; 81 82 @Parameters(name = "{0}") modelsList()83 public static List<TestModels.TestModelEntry> modelsList() { 84 List<TestModels.TestModelEntry> models = new ArrayList<>(); 85 for (String modelName : MODEL_NAMES) { 86 models.add(TestModels.getModelByName(modelName)); 87 } 88 return Collections.unmodifiableList(models); 89 } 90 91 @Before setUp()92 public void setUp() throws Exception { 93 mActivity = mActivityRule.getActivity(); 94 } 95 96 @Test 97 @RequiresDevice 98 @LargeTest testNNAPI()99 public void testNNAPI() throws BenchmarkException, IOException { 100 List<String> accelerators = new ArrayList<>(); 101 NNTestBase.getAcceleratorNames(accelerators); 102 for (String accelerator : accelerators) { 103 if (accelerator.equals("nnapi-reference")) { // Skip. 104 continue; 105 } 106 107 try (NNTestBase test = mModel.createNNTestBase(/*useNNAPI=*/true, 108 /*enableIntermediateTensorsDump=*/false)) { 109 test.setNNApiDeviceName(accelerator); 110 if (!test.setupModel(mActivity)) { 111 Log.d(TAG, String.format( 112 "Cannot initialise test '%s' on accelerator %s, skipping", 113 mModel.mModelName, accelerator)); 114 continue; 115 } 116 Pair<List<InferenceInOutSequence>, List<InferenceResult>> inferenceResults = 117 test.runBenchmarkCompleteInputSet(/*setRepeat=*/1, /*timeoutSec=*/3600); 118 BenchmarkResult benchmarkResult = 119 BenchmarkResult.fromInferenceResults( 120 mModel.mModelName, 121 BenchmarkResult.BACKEND_TFLITE_NNAPI, 122 inferenceResults.first, 123 inferenceResults.second, 124 test.getEvaluator()); 125 assertFalse(benchmarkResult.hasValidationErrors()); 126 } 127 } 128 } 129 } 130