1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.cts;
18 
19 import static junit.framework.TestCase.assertFalse;
20 
21 import android.app.Activity;
22 import android.util.Log;
23 import android.util.Pair;
24 
25 import androidx.test.filters.LargeTest;
26 import androidx.test.filters.RequiresDevice;
27 import androidx.test.rule.ActivityTestRule;
28 
29 import com.android.nn.benchmark.core.BenchmarkException;
30 import com.android.nn.benchmark.core.BenchmarkResult;
31 import com.android.nn.benchmark.core.InferenceInOutSequence;
32 import com.android.nn.benchmark.core.InferenceResult;
33 import com.android.nn.benchmark.core.NNTestBase;
34 import com.android.nn.benchmark.core.TestModels;
35 
36 import org.junit.Before;
37 import org.junit.Rule;
38 import org.junit.Test;
39 import org.junit.runner.RunWith;
40 import org.junit.runners.Parameterized;
41 import org.junit.runners.Parameterized.Parameters;
42 
43 import java.io.IOException;
44 import java.util.ArrayList;
45 import java.util.Collections;
46 import java.util.List;
47 
48 /**
49  * Tests the accuracy of the model outputs.
50  */
51 @RunWith(Parameterized.class)
52 public class NNAccuracyTest {
53     protected static final String TAG = NNAccuracyTest.class.getSimpleName();
54 
55     @Rule
56     public ActivityTestRule<NNAccuracyActivity> mActivityRule =
57             new ActivityTestRule<>(NNAccuracyActivity.class);
58 
59     @Parameterized.Parameter(0)
60     public TestModels.TestModelEntry mModel;
61 
62     private Activity mActivity;
63 
64     // TODO(vddang): Add mobilenet_v1_0.25_128_quant_topk_aosp
65     private static final String[] MODEL_NAMES = new String[]{
66             "tts_float",
67             "asr_float",
68             "mobilenet_v1_1.0_224_quant_topk_aosp",
69             "mobilenet_v1_1.0_224_topk_aosp",
70             "mobilenet_v1_0.75_192_quant_topk_aosp",
71             "mobilenet_v1_0.75_192_topk_aosp",
72             "mobilenet_v1_0.5_160_quant_topk_aosp",
73             "mobilenet_v1_0.5_160_topk_aosp",
74             "mobilenet_v1_0.25_128_topk_aosp",
75             "mobilenet_v2_0.35_128_topk_aosp",
76             "mobilenet_v2_0.5_160_topk_aosp",
77             "mobilenet_v2_0.75_192_topk_aosp",
78             "mobilenet_v2_1.0_224_quant_topk_aosp",
79             "mobilenet_v2_1.0_224_topk_aosp",
80     };
81 
82     @Parameters(name = "{0}")
modelsList()83     public static List<TestModels.TestModelEntry> modelsList() {
84         List<TestModels.TestModelEntry> models = new ArrayList<>();
85         for (String modelName : MODEL_NAMES) {
86             models.add(TestModels.getModelByName(modelName));
87         }
88         return Collections.unmodifiableList(models);
89     }
90 
91     @Before
setUp()92     public void setUp() throws Exception {
93         mActivity = mActivityRule.getActivity();
94     }
95 
96     @Test
97     @RequiresDevice
98     @LargeTest
testNNAPI()99     public void testNNAPI() throws BenchmarkException, IOException {
100         List<String> accelerators = new ArrayList<>();
101         NNTestBase.getAcceleratorNames(accelerators);
102         for (String accelerator : accelerators) {
103             if (accelerator.equals("nnapi-reference")) {  // Skip.
104                 continue;
105             }
106 
107             try (NNTestBase test = mModel.createNNTestBase(/*useNNAPI=*/true,
108                         /*enableIntermediateTensorsDump=*/false)) {
109                 test.setNNApiDeviceName(accelerator);
110                 if (!test.setupModel(mActivity)) {
111                     Log.d(TAG, String.format(
112                         "Cannot initialise test '%s' on accelerator %s, skipping",
113                         mModel.mModelName, accelerator));
114                     continue;
115                 }
116                 Pair<List<InferenceInOutSequence>, List<InferenceResult>> inferenceResults =
117                         test.runBenchmarkCompleteInputSet(/*setRepeat=*/1, /*timeoutSec=*/3600);
118                 BenchmarkResult benchmarkResult =
119                         BenchmarkResult.fromInferenceResults(
120                                 mModel.mModelName,
121                                 BenchmarkResult.BACKEND_TFLITE_NNAPI,
122                                 inferenceResults.first,
123                                 inferenceResults.second,
124                                 test.getEvaluator());
125                 assertFalse(benchmarkResult.hasValidationErrors());
126             }
127         }
128     }
129 }
130