1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <HalInterfaces.h>
18 #include <SampleDriver.h>
19 #include <ValidateHal.h>
20 #include <gtest/gtest.h>
21 
22 #include <algorithm>
23 #include <atomic>
24 #include <cassert>
25 #include <functional>
26 #include <memory>
27 #include <string>
28 #include <thread>
29 #include <tuple>
30 #include <vector>
31 
32 #include "CompilationBuilder.h"
33 #include "ExecutionBurstServer.h"
34 #include "ExecutionCallback.h"
35 #include "HalUtils.h"
36 #include "Manager.h"
37 #include "ModelBuilder.h"
38 #include "NeuralNetworks.h"
39 #include "PreparedModelCallback.h"
40 #include "TestNeuralNetworksWrapper.h"
41 
42 namespace android {
43 
44 namespace V1_0 = ::android::hardware::neuralnetworks::V1_0;
45 namespace V1_1 = ::android::hardware::neuralnetworks::V1_1;
46 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
47 namespace V1_3 = ::android::hardware::neuralnetworks::V1_3;
48 using CompilationBuilder = nn::CompilationBuilder;
49 using Device = nn::Device;
50 using SharedDevice = nn::SharedDevice;
51 using DeviceManager = nn::DeviceManager;
52 using HidlModel = V1_3::Model;
53 using PreparedModelCallback = nn::PreparedModelCallback;
54 using SampleDriver = nn::sample_driver::SampleDriver;
55 using WrapperCompilation = nn::test_wrapper::Compilation;
56 using WrapperEvent = nn::test_wrapper::Event;
57 using WrapperExecution = nn::test_wrapper::Execution;
58 using WrapperModel = nn::test_wrapper::Model;
59 using WrapperOperandType = nn::test_wrapper::OperandType;
60 using WrapperResult = nn::test_wrapper::Result;
61 using WrapperType = nn::test_wrapper::Type;
62 using nn::convertToV1_0;
63 using nn::convertToV1_3;
64 using nn::ErrorStatus;
65 
66 template <typename T>
67 using MQDescriptorSync = hardware::MQDescriptorSync<T>;
68 
69 namespace {
70 
// Sentinel "no timing information" value (UINT64_MAX per the HAL convention),
// reported whenever an execution fails before any measurement can be taken.
const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
72 
73 // Wraps the latest version of IPreparedModel to allow dummying up the execution status,
74 // and control when the execution finishes.
// Wraps the latest version of IPreparedModel to allow dummying up the execution status,
// and control when the execution finishes.
class TestPreparedModelLatest : public V1_3::IPreparedModel {
   public:
    // If errorStatus is NONE, then execute behaves normally (and sends back
    // the actual execution status).  Otherwise, don't bother to execute, and
    // just send back errorStatus (as the execution status, not the launch
    // status).
    TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
        : mPreparedModelV1_0(preparedModel),
          // castFrom() yields nullptr when the underlying model does not speak the
          // newer interface; each method CHECKs the pointer it needs before use.
          mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mErrorStatus(errorStatus) {}

    // Asynchronous V1_0 entry point.  Work happens on a detached thread so the
    // launch returns immediately, like a real asynchronous driver; the thread
    // first runs dummyExecution() so tests can pause/observe the execution.
    hardware::Return<V1_0::ErrorStatus> execute(
            const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
        std::thread([this, request, callback] {
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_0->execute(request, callback);
            } else {
                callback->notify(convertToV1_0(mErrorStatus));
            }
        }).detach();
        return V1_0::ErrorStatus::NONE;
    }

    // Asynchronous V1_2 entry point.  Same detached-thread scheme as execute();
    // OUTPUT_INSUFFICIENT_SIZE additionally fabricates a single too-small
    // output shape so callers exercise their dynamic-output-shape handling.
    hardware::Return<V1_0::ErrorStatus> execute_1_2(
            const V1_0::Request& request, V1_2::MeasureTiming measure,
            const sp<V1_2::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        std::thread([this, request, measure, callback] {
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_2->execute_1_2(request, measure, callback);
            } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_2(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            } else {
                callback->notify_1_2(convertToV1_0(mErrorStatus), {}, kBadTiming);
            }
        }).detach();
        return V1_0::ErrorStatus::NONE;
    }

    // Asynchronous V1_3 entry point; like execute_1_2 but forwards the deadline
    // and loop-timeout arguments and reports V1_3 statuses without conversion.
    hardware::Return<V1_3::ErrorStatus> execute_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            const sp<V1_3::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        std::thread([this, request, measure, deadline, loopTimeoutDuration, callback] {
            dummyExecution();
            if (mErrorStatus == V1_3::ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_3->execute_1_3(request, measure, deadline,
                                                      loopTimeoutDuration, callback);
            } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_3(mErrorStatus, {shape}, kBadTiming);
            } else {
                callback->notify_1_3(mErrorStatus, {}, kBadTiming);
            }
        }).detach();
        return V1_3::ErrorStatus::NONE;
    }

    // Synchronous V1_2 entry point: runs dummyExecution() on the caller's
    // thread, then either forwards to the real model or reports the injected
    // status through the result callback.
    hardware::Return<void> executeSynchronously(const V1_0::Request& request,
                                                V1_2::MeasureTiming measure,
                                                executeSynchronously_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_2->executeSynchronously(request, measure, cb);
        } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            return hardware::Void();
        } else {
            cb(convertToV1_0(mErrorStatus), {}, kBadTiming);
            return hardware::Void();
        }
    }

    // Synchronous V1_3 entry point; same scheme as executeSynchronously with
    // the extra V1_3 deadline/timeout parameters.
    hardware::Return<void> executeSynchronously_1_3(
            const V1_3::Request& request, V1_2::MeasureTiming measure,
            const V1_3::OptionalTimePoint& deadline,
            const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
            executeSynchronously_1_3_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeSynchronously_1_3(request, measure, deadline,
                                                                loopTimeoutDuration, cb);
        } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(mErrorStatus, {shape}, kBadTiming);
            return hardware::Void();
        } else {
            cb(mErrorStatus, {}, kBadTiming);
            return hardware::Void();
        }
    }

    // ExecutionBurstServer::create has an overload that will use
    // IPreparedModel::executeSynchronously(), so we can rely on that, rather
    // than having to implement ExecutionBurstServer::IExecutorWithCache.
    hardware::Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            const sp<V1_2::IBurstContext> burst =
                    nn::ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);

            cb(burst == nullptr ? V1_0::ErrorStatus::GENERAL_FAILURE : V1_0::ErrorStatus::NONE,
               burst);
            return hardware::Void();
        } else {
            cb(convertToV1_0(mErrorStatus), nullptr);
            return hardware::Void();
        }
    }

    // Note, due to the limitation of SampleDriver implementation, the call is
    // synchronous.  The test code that exercises this implementation of
    // SampleDriver is written with that in mind.  Therefore, this
    // implementation is synchronous also.  If the SampleDriver is updated to
    // return real sync fence, this must be updated.
    hardware::Return<void> executeFenced(const V1_3::Request& request,
                                         const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
                                         V1_2::MeasureTiming measure,
                                         const V1_3::OptionalTimePoint& deadline,
                                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                                         const V1_3::OptionalTimeoutDuration& duration,
                                         executeFenced_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        CHECK(mErrorStatus != V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE)
                << "executeFenced does not support dynamic output shape";
        dummyExecution();
        if (mErrorStatus == V1_3::ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeFenced(request, waitFor, measure, deadline,
                                                     loopTimeoutDuration, duration, cb);
        } else {
            // Due to the limitations of the SampleDriver, all failures look
            // like launch failures.  If the SampleDriver is updated to return
            // real sync fences, this must be updated.
            cb(mErrorStatus, hardware::hidl_handle(nullptr), nullptr);
        }
        return hardware::Void();
    }

    // We can place the TestPreparedModelLatest system in a "pause" mode where
    // no execution will complete until the system is taken out of that mode.
    // Initially, the system is not in that mode.
    static void pauseExecutions(bool v) { mPauseExecutions.store(v); }

    // This function is only guaranteed to work in the following pattern:
    // Consider thread A as primary thread
    // - thread A: pauseExecutions(true);
    // - thread A: launch execution (as thread B)
    // - thread A: waitForExecutionToBegin(), block until call to dummyExecution by
    //                                        thread B makes mExecutionsInFlight nonzero
    // - thread B: dummyExecution(), which makes mExecutionsInFlight nonzero and blocks
    //                               until thread A calls pauseExecutions(false)
    // - thread A: waitForExecutionToBegin() returns
    // - thread A: pauseExecutions(false), allowing dummyExecution() on thread B to continue
    // - thread B: dummyExecution() zeroes mExecutionsInFlight and returns
    // - thread B: thread exits
    static void waitForExecutionToBegin() {
        CHECK(mPauseExecutions.load());
        // Busy-wait; acceptable in tests because the window is short and the
        // protocol above guarantees forward progress.
        while (mExecutionsInFlight.load() == 0) {
        }
    }

   private:
    const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
    const sp<V1_2::IPreparedModel> mPreparedModelV1_2;  // may be nullptr
    const sp<V1_3::IPreparedModel> mPreparedModelV1_3;  // may be nullptr
    // Injected status; NONE means "behave normally" (set once in ctor).
    V1_3::ErrorStatus mErrorStatus;

    // Process-wide pause flag and in-flight execution counter backing the
    // pauseExecutions()/waitForExecutionToBegin() protocol documented above.
    static std::atomic<bool> mPauseExecutions;
    static std::atomic<unsigned int> mExecutionsInFlight;

    // Marks an execution as in flight, spins while paused, then unmarks it.
    // The CHECK_EQ enforces the single-concurrent-execution assumption.
    static void dummyExecution() {
        CHECK_EQ(mExecutionsInFlight.fetch_add(1), 0u) << "We do not support concurrent executions";
        while (mPauseExecutions.load()) {
        }
        mExecutionsInFlight.fetch_sub(1);
    }
};
// Definitions of the static pause/in-flight state shared by all instances.
std::atomic<bool> TestPreparedModelLatest::mPauseExecutions = false;
std::atomic<unsigned int> TestPreparedModelLatest::mExecutionsInFlight = 0;

// The latest-version wrapper doubles as the 1.3 test model.
using TestPreparedModel13 = TestPreparedModelLatest;
273 
274 // Like TestPreparedModelLatest, but implementing 1.2
275 class TestPreparedModel12 : public V1_2::IPreparedModel {
276    public:
TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)277     TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
278         : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
279 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)280     hardware::Return<V1_0::ErrorStatus> execute(
281             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
282         return mLatestPreparedModel->execute(request, callback);
283     }
284 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)285     hardware::Return<V1_0::ErrorStatus> execute_1_2(
286             const V1_0::Request& request, V1_2::MeasureTiming measure,
287             const sp<V1_2::IExecutionCallback>& callback) override {
288         return mLatestPreparedModel->execute_1_2(request, measure, callback);
289     }
290 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)291     hardware::Return<void> executeSynchronously(const V1_0::Request& request,
292                                                 V1_2::MeasureTiming measure,
293                                                 executeSynchronously_cb cb) override {
294         return mLatestPreparedModel->executeSynchronously(request, measure, cb);
295     }
296 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)297     hardware::Return<void> configureExecutionBurst(
298             const sp<V1_2::IBurstCallback>& callback,
299             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
300             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
301             configureExecutionBurst_cb cb) override {
302         return mLatestPreparedModel->configureExecutionBurst(callback, requestChannel,
303                                                              resultChannel, cb);
304     }
305 
306    private:
307     const sp<V1_3::IPreparedModel> mLatestPreparedModel;
308 };
309 
310 // Like TestPreparedModelLatest, but implementing 1.0
311 class TestPreparedModel10 : public V1_0::IPreparedModel {
312    public:
TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)313     TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
314         : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
315 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)316     hardware::Return<V1_0::ErrorStatus> execute(
317             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
318         return mLatestPreparedModel->execute(request, callback);
319     }
320 
321    private:
322     const sp<V1_3::IPreparedModel> mLatestPreparedModel;
323 };
324 
325 // Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
// Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
class TestDriver13 : public SampleDriver {
   public:
    // Allow dummying up the error status for execution of all models
    // prepared from this driver.  If errorStatus is NONE, then
    // execute behaves normally (and sends back the actual execution
    // status). Otherwise, don't bother to execute, and just send
    // back errorStatus (as the execution status, not the launch
    // status).
    TestDriver13(const std::string& name, V1_3::ErrorStatus errorStatus)
        : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}

    // Advertises uniform, middling performance for everything so the device
    // is a candidate for any operation.
    hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb) override {
        android::nn::initVLogMask();
        const V1_0::PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
        V1_3::Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance =
                        nn::nonExtensionOperandPerformance<nn::HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        _hidl_cb(V1_3::ErrorStatus::NONE, capabilities);
        return hardware::Void();
    }

    // Claims support for every operation of a valid model.
    hardware::Return<void> getSupportedOperations_1_3(const HidlModel& model,
                                                      getSupportedOperations_1_3_cb cb) override {
        if (nn::validateModel(model)) {
            std::vector<bool> supported(model.main.operations.size(), true);
            cb(V1_3::ErrorStatus::NONE, supported);
        } else {
            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
        }
        return hardware::Void();
    }

    // Prepares the model via the real SampleDriver (intercepting its callback),
    // then, on success, hands the caller a TestPreparedModel13 wrapper that
    // injects mErrorStatus at execution time.
    hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
            const HidlModel& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
            const V1_3::OptionalTimePoint& deadline,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
            const nn::HalCacheToken& token,
            const sp<V1_3::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_3::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_3(
                model, preference, priority, deadline, modelCache, dataCache, token, localCallback);
        // Transport-level failure: propagate without touching the callback.
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        // Launch failure: forward whatever the local callback holds.
        // NOTE(review): this path notifies before localCallback->wait() —
        // presumably safe because the SampleDriver calls its callback before
        // returning a non-NONE launch status; confirm against SampleDriver.
        if (prepareModelReturn != V1_3::ErrorStatus::NONE) {
            actualCallback->notify_1_3(
                    convertToV1_3(localCallback->getStatus()),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            // Preparation failed: pass the real status/model through unwrapped.
            actualCallback->notify_1_3(
                    convertToV1_3(localCallback->getStatus()),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            // Success: wrap the prepared model so executions can be dummied up.
            actualCallback->notify_1_3(
                    V1_3::ErrorStatus::NONE,
                    new TestPreparedModel13(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // Same interception scheme as prepareModel_1_3, for the 1.2 interface.
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
            const V1_2::Model& model, V1_1::ExecutionPreference preference,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
            const nn::HalCacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_0::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
                model, preference, modelCache, dataCache, token, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            actualCallback->notify_1_2(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // Same interception scheme as prepareModel_1_3, for the 1.1 interface.
    hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, V1_1::ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        hardware::Return<V1_0::ErrorStatus> prepareModelReturn =
                SampleDriver::prepareModel_1_1(model, preference, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
        } else {
            actualCallback->notify(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    // 1.0 preparation is routed through the 1.1 path with a default preference.
    hardware::Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return prepareModel_1_1(nn::convertToV1_1(model),
                                V1_1::ExecutionPreference::FAST_SINGLE_ANSWER, actualCallback);
    }

   private:
    // Status injected into every prepared model produced by this driver.
    V1_3::ErrorStatus mErrorStatus;
};
461 
462 // Like TestDriver, but implementing 1.2
463 class TestDriver12 : public V1_2::IDevice {
464    public:
TestDriver12(const std::string & name,V1_3::ErrorStatus errorStatus)465     TestDriver12(const std::string& name, V1_3::ErrorStatus errorStatus)
466         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)467     hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
468         return mLatestDriver->getCapabilities_1_2(_hidl_cb);
469     }
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)470     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
471         return mLatestDriver->getCapabilities_1_1(_hidl_cb);
472     }
getCapabilities(getCapabilities_cb _hidl_cb)473     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
474         return mLatestDriver->getCapabilities(_hidl_cb);
475     }
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb _hidl_cb)476     hardware::Return<void> getSupportedOperations_1_2(
477             const V1_2::Model& model, getSupportedOperations_1_2_cb _hidl_cb) override {
478         return mLatestDriver->getSupportedOperations_1_2(model, _hidl_cb);
479     }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)480     hardware::Return<void> getSupportedOperations_1_1(
481             const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
482         return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
483     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)484     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
485                                                   getSupportedOperations_cb _hidl_cb) override {
486         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
487     }
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)488     hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
489             const V1_2::Model& model, V1_1::ExecutionPreference preference,
490             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
491             const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
492             const nn::HalCacheToken& token,
493             const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
494         return mLatestDriver->prepareModel_1_2(model, preference, modelCache, dataCache, token,
495                                                actualCallback);
496     }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)497     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
498             const V1_1::Model& model, V1_1::ExecutionPreference preference,
499             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
500         return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
501     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)502     hardware::Return<V1_0::ErrorStatus> prepareModel(
503             const V1_0::Model& model,
504             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
505         return mLatestDriver->prepareModel(model, actualCallback);
506     }
getStatus()507     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getVersionString(getVersionString_cb _hidl_cb)508     hardware::Return<void> getVersionString(getVersionString_cb _hidl_cb) override {
509         return mLatestDriver->getVersionString(_hidl_cb);
510     }
getType(getType_cb _hidl_cb)511     hardware::Return<void> getType(getType_cb _hidl_cb) override {
512         return mLatestDriver->getType(_hidl_cb);
513     }
getSupportedExtensions(getSupportedExtensions_cb _hidl_cb)514     hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb _hidl_cb) {
515         return mLatestDriver->getSupportedExtensions(_hidl_cb);
516     }
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb)517     hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb) {
518         return mLatestDriver->getNumberOfCacheFilesNeeded(_hidl_cb);
519     }
prepareModelFromCache(const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & callback)520     hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
521             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
522             const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
523             const nn::HalCacheToken& token, const sp<V1_2::IPreparedModelCallback>& callback) {
524         return mLatestDriver->prepareModelFromCache(modelCache, dataCache, token, callback);
525     }
526 
527    private:
528     const sp<V1_3::IDevice> mLatestDriver;
529 };
530 
531 // Like TestDriver, but implementing 1.1
532 class TestDriver11 : public V1_1::IDevice {
533    public:
TestDriver11(const std::string & name,V1_3::ErrorStatus errorStatus)534     TestDriver11(const std::string& name, V1_3::ErrorStatus errorStatus)
535         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)536     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
537         return mLatestDriver->getCapabilities_1_1(_hidl_cb);
538     }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)539     hardware::Return<void> getSupportedOperations_1_1(
540             const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
541         return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
542     }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)543     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
544             const V1_1::Model& model, V1_1::ExecutionPreference preference,
545             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
546         return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
547     }
getStatus()548     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getCapabilities(getCapabilities_cb _hidl_cb)549     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
550         return mLatestDriver->getCapabilities(_hidl_cb);
551     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)552     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
553                                                   getSupportedOperations_cb _hidl_cb) override {
554         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
555     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)556     hardware::Return<V1_0::ErrorStatus> prepareModel(
557             const V1_0::Model& model,
558             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
559         return mLatestDriver->prepareModel(model, actualCallback);
560     }
561 
562    private:
563     const sp<V1_3::IDevice> mLatestDriver;
564 };
565 
566 // Like TestDriver, but implementing 1.0
567 class TestDriver10 : public V1_0::IDevice {
568    public:
TestDriver10(const std::string & name,V1_3::ErrorStatus errorStatus)569     TestDriver10(const std::string& name, V1_3::ErrorStatus errorStatus)
570         : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities(getCapabilities_cb _hidl_cb)571     hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
572         return mLatestDriver->getCapabilities(_hidl_cb);
573     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)574     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
575                                                   getSupportedOperations_cb _hidl_cb) override {
576         return mLatestDriver->getSupportedOperations(model, _hidl_cb);
577     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)578     hardware::Return<V1_0::ErrorStatus> prepareModel(
579             const V1_0::Model& model,
580             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
581         return mLatestDriver->prepareModel(model, actualCallback);
582     }
getStatus()583     hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
584 
585    private:
586     const sp<V1_3::IDevice> mLatestDriver;
587 };
588 
589 // This class adds some simple utilities on top of WrapperCompilation in order
590 // to provide access to certain features from CompilationBuilder that are not
591 // exposed by the base class.
592 template <typename DriverClass>
593 class TestCompilation : public WrapperCompilation {
594    public:
595     // Allow dummying up the error status for all executions from this
596     // compilation.  If errorStatus is NONE, then execute behaves
597     // normally (and sends back the actual execution status).
598     // Otherwise, don't bother to execute, and just send back
599     // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,V1_3::ErrorStatus errorStatus)600     TestCompilation(const WrapperModel* model, const std::string& deviceName,
601                     V1_3::ErrorStatus errorStatus) {
602         std::vector<std::shared_ptr<Device>> devices;
603         auto device = DeviceManager::forTest_makeDriverDevice(
604                 nn::makeSharedDevice(deviceName, new DriverClass(deviceName, errorStatus)));
605         devices.push_back(device);
606 
607         nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
608         CompilationBuilder* c = nullptr;
609         int result = m->createCompilation(&c, devices);
610         EXPECT_EQ(result, 0);
611         // We need to ensure that we use our TestDriver and do not
612         // fall back to CPU.  (If we allow CPU fallback, then when our
613         // TestDriver reports an execution failure, we'll re-execute
614         // on CPU, and will not see the failure.)
615         c->forTest_setPartitioning(DeviceManager::kPartitioningWithoutFallback);
616         mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
617     }
618 };
619 
620 // This class has roughly the same functionality as TestCompilation class.
621 // The major difference is that Introspection API is used to select the device.
622 class TestIntrospectionCompilation : public WrapperCompilation {
623    public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)624     TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
625         std::vector<ANeuralNetworksDevice*> mDevices;
626         uint32_t numDevices = 0;
627         EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
628         EXPECT_GE(numDevices, (uint32_t)1);
629 
630         for (uint32_t i = 0; i < numDevices; i++) {
631             ANeuralNetworksDevice* device = nullptr;
632             EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
633             const char* buffer = nullptr;
634             int result = ANeuralNetworksDevice_getName(device, &buffer);
635             if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
636                 mDevices.push_back(device);
637             }
638         }
639         // In CPU only mode, DeviceManager::getDrivers() will not be able to
640         // provide the actual device list. We will not be able to find the test
641         // driver with specified deviceName.
642         if (!DeviceManager::get()->getUseCpuOnly()) {
643             EXPECT_EQ(mDevices.size(), (uint32_t)1);
644 
645             int result = ANeuralNetworksCompilation_createForDevices(
646                     model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
647             EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
648         }
649     }
650 };
651 
652 template <class DriverClass>
653 class ExecutionTestTemplate
654     : public ::testing::TestWithParam<std::tuple<V1_3::ErrorStatus, WrapperResult, bool>> {
655    public:
ExecutionTestTemplate()656     ExecutionTestTemplate()
657         : kName(toString(std::get<0>(GetParam()))),
658           kForceErrorStatus(std::get<0>(GetParam())),
659           kExpectResult(std::get<1>(GetParam())),
660           kUseIntrospectionAPI(std::get<2>(GetParam())),
661           mModel(makeModel()) {
662         if (kUseIntrospectionAPI) {
663             DeviceManager::get()->forTest_registerDevice(
664                     nn::makeSharedDevice(kName, new DriverClass(kName.c_str(), kForceErrorStatus)));
665             mCompilation = TestIntrospectionCompilation(&mModel, kName);
666         } else {
667             mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
668         }
669     }
670 
671    protected:
672     // Unit test method
673     // Set "reusable" to true to test reusable execution; Otherwise, test non-reusable execution.
674     void TestWait(bool reusable);
675 
TearDown()676     virtual void TearDown() {
677         // Reinitialize the device list since Introspection API path altered it.
678         if (kUseIntrospectionAPI) {
679             DeviceManager::get()->forTest_reInitializeDeviceList();
680         }
681     }
682 
getDimensionsWhileRunning(WrapperExecution & execution)683     void getDimensionsWhileRunning(WrapperExecution& execution) {
684         TestPreparedModelLatest::waitForExecutionToBegin();
685         // Cannot query dimensions while execution is running
686         std::vector<uint32_t> dimensions;
687         EXPECT_EQ(execution.getOutputOperandDimensions(0, &dimensions), WrapperResult::BAD_STATE);
688     }
689 
690     const std::string kName;
691 
692     // Allow dummying up the error status for execution.  If
693     // kForceErrorStatus is NONE, then execution behaves normally (and
694     // sends back the actual execution status).  Otherwise, don't
695     // bother to execute, and just send back kForceErrorStatus (as the
696     // execution status, not the launch status).
697     const V1_3::ErrorStatus kForceErrorStatus;
698 
699     // What result do we expect from the execution?  (The WrapperResult
700     // equivalent of kForceErrorStatus.)
701     const WrapperResult kExpectResult;
702 
703     // Whether mCompilation is created via Introspection API or not.
704     const bool kUseIntrospectionAPI;
705 
706     WrapperModel mModel;
707     WrapperCompilation mCompilation;
708 
setInputOutput(WrapperExecution * execution)709     void setInputOutput(WrapperExecution* execution) {
710         mInputBuffer = kInputBuffer;
711         mOutputBuffer = kOutputBufferInitial;
712         ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)),
713                   WrapperResult::NO_ERROR);
714         ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)),
715                   WrapperResult::NO_ERROR);
716     }
717 
718     const float kInputBuffer = 3.14;
719     const float kOutputBufferInitial = 0;
720     float mInputBuffer;
721     float mOutputBuffer;
722     const float kOutputBufferExpected = 3;
723     const std::vector<uint32_t> kOutputDimensionsExpected = {1};
724 
725    private:
makeModel()726     static WrapperModel makeModel() {
727         static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});
728 
729         WrapperModel model;
730         uint32_t input = model.addOperand(&tensorType);
731         uint32_t output = model.addOperand(&tensorType);
732         model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
733         model.identifyInputsAndOutputs({input}, {output});
734         assert(model.finish() == WrapperResult::NO_ERROR);
735 
736         return model;
737     }
738 };
739 
computeHelper(bool reusable,const std::function<void ()> & compute)740 void computeHelper(bool reusable, const std::function<void()>& compute) {
741     {
742         SCOPED_TRACE(reusable ? "first time reusable" : "non-reusable");
743         compute();
744     }
745     if (reusable) {
746         SCOPED_TRACE("second time reusable");
747         compute();
748     }
749 }
750 
// Exercises all four client-visible compute paths against the fixture's
// driver: startCompute (async), compute (sync), compute in BURST mode, and
// startComputeWithDependencies (fenced).  For each path it checks that the
// client observes kExpectResult, that the output buffer is written on
// success, and that getOutputOperandDimensions behaves correctly both while
// the execution is paused inside the driver and after it finishes.
template <class DriverClass>
void ExecutionTestTemplate<DriverClass>::TestWait(bool reusable) {
    SCOPED_TRACE(kName);
    // Skip Introspection API tests when CPU only flag is forced on.
    if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
        GTEST_SKIP();
    }

    ASSERT_EQ(mCompilation.finish(), WrapperResult::NO_ERROR);

    {
        SCOPED_TRACE("startCompute");
        WrapperExecution execution(&mCompilation);
        ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        const auto compute = [this, &execution] {
            // Pause the driver so we can probe the execution mid-flight.
            TestPreparedModelLatest::pauseExecutions(true);
            WrapperEvent event;
            ASSERT_EQ(execution.startCompute(&event), WrapperResult::NO_ERROR);
            getDimensionsWhileRunning(execution);
            TestPreparedModelLatest::pauseExecutions(false);
            ASSERT_EQ(event.wait(), kExpectResult);
            if (kExpectResult == WrapperResult::NO_ERROR) {
                ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
            }
            std::vector<uint32_t> dimensions;
            if (kExpectResult == WrapperResult::NO_ERROR ||
                kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
                // Only one output operand, hardcoded as index 0.
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
                ASSERT_EQ(dimensions, kOutputDimensionsExpected);
            } else {
                // Any other execution failure leaves the execution in a state
                // where dimension queries report BAD_STATE.
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
                          WrapperResult::BAD_STATE);
            }
        };
        computeHelper(reusable, compute);
    }
    {
        SCOPED_TRACE("compute");
        WrapperExecution execution(&mCompilation);
        ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        const auto compute = [this, &execution] {
            // compute() is synchronous, so it runs on a separate thread while
            // this thread probes the paused execution.
            TestPreparedModelLatest::pauseExecutions(true);
            std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
            getDimensionsWhileRunning(execution);
            TestPreparedModelLatest::pauseExecutions(false);
            run.join();
            if (kExpectResult == WrapperResult::NO_ERROR) {
                ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
            }
            std::vector<uint32_t> dimensions;
            if (kExpectResult == WrapperResult::NO_ERROR ||
                kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
                // Only one output operand, hardcoded as index 0.
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
                ASSERT_EQ(dimensions, kOutputDimensionsExpected);
            } else {
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
                          WrapperResult::BAD_STATE);
            }
        };
        computeHelper(reusable, compute);
    }
    {
        SCOPED_TRACE("burstCompute");

        // TODO: If a burst API is added to nn::test_wrapper (e.g.,
        // Execution::burstCompute()), then use that, rather than
        // Execution::compute(WrapperExecution::ComputeMode::BURST).

        WrapperExecution execution(&mCompilation);
        ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        const auto compute = [this, &execution] {
            TestPreparedModelLatest::pauseExecutions(true);
            std::thread run([this, &execution] {
                EXPECT_EQ(execution.compute(WrapperExecution::ComputeMode::BURST), kExpectResult);
            });
            getDimensionsWhileRunning(execution);
            TestPreparedModelLatest::pauseExecutions(false);
            run.join();
            if (kExpectResult == WrapperResult::NO_ERROR) {
                ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
            }
            std::vector<uint32_t> dimensions;
            if (kExpectResult == WrapperResult::NO_ERROR ||
                kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
                // Only one output operand, hardcoded as index 0.
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
                ASSERT_EQ(dimensions, kOutputDimensionsExpected);
            } else {
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
                          WrapperResult::BAD_STATE);
            }
        };
        computeHelper(reusable, compute);
    }
    if (kExpectResult != WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
        // computeWithDependencies doesn't support OUTPUT_INSUFFICIENT_SIZE
        SCOPED_TRACE("computeWithDependencies");
        WrapperExecution execution(&mCompilation);
        ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));

        const auto compute = [this, &execution] {
            TestPreparedModelLatest::pauseExecutions(true);

            WrapperEvent event;
            // Note, due to the limitation of SampleDriver implementation, the call is synchronous.
            // If the SampleDriver is updated to return real sync fence, this must be updated.
            std::thread run([this, &execution, &event] {
                EXPECT_EQ(execution.startComputeWithDependencies({}, 0, &event), kExpectResult);
            });
            getDimensionsWhileRunning(execution);
            TestPreparedModelLatest::pauseExecutions(false);
            run.join();
            if (kExpectResult == WrapperResult::NO_ERROR) {
                ASSERT_EQ(event.wait(), kExpectResult);
                ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
            } else {
                // On launch failure no event is produced, so wait() reports
                // UNEXPECTED_NULL rather than the forced error status.
                ASSERT_EQ(event.wait(), WrapperResult::UNEXPECTED_NULL);
            }
            std::vector<uint32_t> dimensions;
            if (kExpectResult == WrapperResult::NO_ERROR) {
                // Only one output operand, hardcoded as index 0.
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
                ASSERT_EQ(dimensions, kOutputDimensionsExpected);
            } else {
                ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
                          WrapperResult::BAD_STATE);
            }
        };
        computeHelper(reusable, compute);
    }
}
888 
// Parameter tuples: {error status the driver is forced to report, the
// WrapperResult the client is expected to observe, kUseIntrospectionAPI}.
// These flavors all select the device via TestCompilation (no Introspection).
auto kTestValues = ::testing::Values(
        std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
                        WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
                        /* kUseIntrospectionAPI */ false));
901 
// Runs TestWait against a driver exposing the 1.3 HAL, with and without
// execution reuse.
class ExecutionTest13 : public ExecutionTestTemplate<TestDriver13> {};
TEST_P(ExecutionTest13, Wait) {
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest13, WaitReusable) {
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest13, kTestValues);
910 
// Runs TestWait against a driver exposing the 1.2 HAL, with and without
// execution reuse.
class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12, Wait) {
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest12, WaitReusable) {
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest12, kTestValues);
919 
// Runs TestWait against a driver exposing the 1.1 HAL.  The
// OUTPUT_INSUFFICIENT_SIZE flavor is skipped — presumably because that error
// reporting requires a newer HAL than 1.1 (TODO confirm).
class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11, Wait) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest11, WaitReusable) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest11, kTestValues);
930 
// Runs TestWait against a driver exposing the 1.0 HAL.  As with 1.1, the
// OUTPUT_INSUFFICIENT_SIZE flavor is skipped for this older HAL version.
class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10, Wait) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/false);
}
TEST_P(ExecutionTest10, WaitReusable) {
    if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait(/*reusable=*/true);
}
INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest10, kTestValues);
941 
// Same error flavors as kTestValues, but with kUseIntrospectionAPI = true so
// the device is selected through the Introspection API.  Only instantiated
// for the 1.3 driver suite.
auto kIntrospectionTestValues = ::testing::Values(
        std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
                        WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
                        /* kUseIntrospectionAPI */ true));

INSTANTIATE_TEST_SUITE_P(IntrospectionFlavor, ExecutionTest13, kIntrospectionTestValues);
956 
957 }  // namespace
958 }  // namespace android
959