1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.core.test;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 
21 import android.annotation.SuppressLint;
22 import android.content.Context;
23 import android.content.Intent;
24 import android.util.Log;
25 
26 import com.android.nn.benchmark.core.Processor;
27 import com.android.nn.crashtest.core.CrashTest;
28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
29 import com.android.nn.benchmark.core.TfLiteBackend;
30 
31 import java.time.Duration;
32 import java.util.ArrayList;
33 import java.util.Collections;
34 import java.util.HashSet;
35 import java.util.List;
36 import java.util.Optional;
37 import java.util.Set;
38 import java.util.concurrent.CountDownLatch;
39 import java.util.concurrent.ExecutionException;
40 import java.util.concurrent.ExecutorService;
41 import java.util.concurrent.Executors;
42 import java.util.concurrent.Future;
43 
44 public class RunModelsInParallel implements CrashTest {
45 
46     private static final String MODELS = "models";
47     private static final String DURATION = "duration";
48     private static final String THREADS = "thread_counts";
49     private static final String TEST_NAME = "test_name";
50     private static final String ACCELERATOR_NAME = "accelerator_name";
51     private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models";
52     private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only";
53     private static final String MEMORY_MAP_MODEL = "memory_map_model";
54 
55     private final Set<Processor> activeTests = new HashSet<>();
56     private final List<Boolean> mTestCompletionResults = Collections.synchronizedList(
57             new ArrayList<>());
58     private long mTestDurationMillis = 0;
59     private int mThreadCount = 0;
60     private int[] mTestList = new int[0];
61     private String mTestName;
62     private String mAcceleratorName;
63     private boolean mIgnoreUnsupportedModels;
64     private Context mContext;
65     private boolean mRunModelCompilationOnly;
66     private ExecutorService mExecutorService = null;
67     private CountDownLatch mParallelTestComplete;
68     private ProgressListener mProgressListener;
69     private boolean mMmapModel;
70 
intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel)71     static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount,
72             Duration duration, String testName, String acceleratorName,
73             boolean ignoreUnsupportedModels,
74             boolean runModelCompilationOnly, boolean mmapModel) {
75         return intent -> {
76             intent.putExtra(MODELS, models);
77             intent.putExtra(DURATION, duration.toMillis());
78             intent.putExtra(THREADS, threadCount);
79             intent.putExtra(TEST_NAME, testName);
80             intent.putExtra(ACCELERATOR_NAME, acceleratorName);
81             intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels);
82             intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly);
83             intent.putExtra(MEMORY_MAP_MODEL, mmapModel);
84         };
85     }
86 
87     @Override
init(Context context, Intent configParams, Optional<ProgressListener> progressListener)88     public void init(Context context, Intent configParams,
89             Optional<ProgressListener> progressListener) {
90         mTestList = configParams.getIntArrayExtra(MODELS);
91         mThreadCount = configParams.getIntExtra(THREADS, 10);
92         mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10);
93         mTestName = configParams.getStringExtra(TEST_NAME);
94         mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
95         mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra(
96                 IGNORE_UNSUPPORTED_MODELS, false);
97         mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false);
98         mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false);
99         mContext = context;
100         mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> {
101             Log.v(CrashTest.TAG, message.orElse("."));
102         });
103         mExecutorService = Executors.newFixedThreadPool(mThreadCount);
104         mTestCompletionResults.clear();
105     }
106 
107     @Override
call()108     public Optional<String> call() {
109         mParallelTestComplete = new CountDownLatch(mThreadCount);
110         for (int i = 0; i < mThreadCount; i++) {
111             Processor testProcessor = createSubTestRunner(mTestList, i);
112 
113             activeTests.add(testProcessor);
114             mExecutorService.submit(testProcessor);
115         }
116 
117         return completedSuccessfully();
118     }
119 
createSubTestRunner(final int[] testList, final int testIndex)120     private Processor createSubTestRunner(final int[] testList, final int testIndex) {
121         final Processor result = new Processor(mContext, new Processor.Callback() {
122             @SuppressLint("DefaultLocale")
123             @Override
124             public void onBenchmarkFinish(boolean ok) {
125                 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex,
126                         ok ? "successfully" : "with failure");
127                 mTestCompletionResults.add(ok);
128                 mParallelTestComplete.countDown();
129             }
130 
131             @Override
132             public void onStatusUpdate(int testNumber, int numTests, String modelName) {
133             }
134         }, testList);
135         result.setTfLiteBackend(TfLiteBackend.NNAPI);
136         result.setCompleteInputSet(false);
137         result.setNnApiAcceleratorName(mAcceleratorName);
138         result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);
139         result.setRunModelCompilationOnly(mRunModelCompilationOnly);
140         result.setMmapModel(mMmapModel);
141         return result;
142     }
143 
endTests()144     private void endTests() {
145         ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size());
146         List<Future<?>> terminationCommands = new ArrayList<>();
147         for (final Processor test : activeTests) {
148             // Exit will block until the thread is completed
149             terminationCommands.add(terminatorsThreadPool.submit(
150                     () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis())));
151         }
152         terminationCommands.forEach(terminationCommand -> {
153             try {
154                 terminationCommand.get();
155             } catch (ExecutionException e) {
156                 Log.w(TAG, "Failure while waiting for completion of tests", e);
157             } catch (InterruptedException e) {
158                 Thread.interrupted();
159             }
160         });
161     }
162 
163     @SuppressLint("DefaultLocale")
notifyProgress(String messageFormat, Object... args)164     void notifyProgress(String messageFormat, Object... args) {
165         mProgressListener.testProgress(Optional.of(String.format(messageFormat, args)));
166     }
167 
168     // This method blocks until the tests complete and returns true if all tests completed
169     // successfully
170     @SuppressLint("DefaultLocale")
completedSuccessfully()171     private Optional<String> completedSuccessfully() {
172         try {
173             boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS);
174             if (!testsEnded) {
175                 Log.i(TAG,
176                         String.format(
177                                 "Test '%s': Tests are not completed (they might have been "
178                                         + "designed to run "
179                                         + "indefinitely. Forcing termination.", mTestName));
180                 endTests();
181             }
182         } catch (InterruptedException ignored) {
183             Thread.currentThread().interrupt();
184         }
185 
186         final long failedTestCount = mTestCompletionResults.stream().filter(
187                 testResult -> !testResult).count();
188         if (failedTestCount > 0) {
189             String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName,
190                     failedTestCount,
191                     mTestCompletionResults.size());
192             Log.w(CrashTest.TAG, failureMsg);
193             return failure(failureMsg);
194         } else {
195             Log.i(CrashTest.TAG,
196                     String.format("Test '%s': Test completed successfully", mTestName));
197             return success();
198         }
199     }
200 }
201