1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <HalInterfaces.h>
18 #include <SampleDriver.h>
19 #include <ValidateHal.h>
20 #include <gtest/gtest.h>
21
22 #include <algorithm>
23 #include <atomic>
24 #include <cassert>
25 #include <functional>
26 #include <memory>
27 #include <string>
28 #include <thread>
29 #include <tuple>
30 #include <vector>
31
32 #include "CompilationBuilder.h"
33 #include "ExecutionBurstServer.h"
34 #include "ExecutionCallback.h"
35 #include "HalUtils.h"
36 #include "Manager.h"
37 #include "ModelBuilder.h"
38 #include "NeuralNetworks.h"
39 #include "PreparedModelCallback.h"
40 #include "TestNeuralNetworksWrapper.h"
41
42 namespace android {
43
44 namespace V1_0 = ::android::hardware::neuralnetworks::V1_0;
45 namespace V1_1 = ::android::hardware::neuralnetworks::V1_1;
46 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
47 namespace V1_3 = ::android::hardware::neuralnetworks::V1_3;
48 using CompilationBuilder = nn::CompilationBuilder;
49 using Device = nn::Device;
50 using SharedDevice = nn::SharedDevice;
51 using DeviceManager = nn::DeviceManager;
52 using HidlModel = V1_3::Model;
53 using PreparedModelCallback = nn::PreparedModelCallback;
54 using SampleDriver = nn::sample_driver::SampleDriver;
55 using WrapperCompilation = nn::test_wrapper::Compilation;
56 using WrapperEvent = nn::test_wrapper::Event;
57 using WrapperExecution = nn::test_wrapper::Execution;
58 using WrapperModel = nn::test_wrapper::Model;
59 using WrapperOperandType = nn::test_wrapper::OperandType;
60 using WrapperResult = nn::test_wrapper::Result;
61 using WrapperType = nn::test_wrapper::Type;
62 using nn::convertToV1_0;
63 using nn::convertToV1_3;
64 using nn::ErrorStatus;
65
66 template <typename T>
67 using MQDescriptorSync = hardware::MQDescriptorSync<T>;
68
69 namespace {
70
71 const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
72
73 // Wraps the latest version of IPreparedModel to allow dummying up the execution status,
74 // and control when the execution finishes.
75 class TestPreparedModelLatest : public V1_3::IPreparedModel {
76 public:
77 // If errorStatus is NONE, then execute behaves normally (and sends back
78 // the actual execution status). Otherwise, don't bother to execute, and
79 // just send back errorStatus (as the execution status, not the launch
80 // status).
TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)81 TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
82 : mPreparedModelV1_0(preparedModel),
83 mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
84 mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
85 mErrorStatus(errorStatus) {}
86
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)87 hardware::Return<V1_0::ErrorStatus> execute(
88 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
89 CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
90 std::thread([this, request, callback] {
91 dummyExecution();
92 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
93 // Note that we lose the actual launch status.
94 (void)mPreparedModelV1_0->execute(request, callback);
95 } else {
96 callback->notify(convertToV1_0(mErrorStatus));
97 }
98 }).detach();
99 return V1_0::ErrorStatus::NONE;
100 }
101
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)102 hardware::Return<V1_0::ErrorStatus> execute_1_2(
103 const V1_0::Request& request, V1_2::MeasureTiming measure,
104 const sp<V1_2::IExecutionCallback>& callback) override {
105 CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
106 std::thread([this, request, measure, callback] {
107 dummyExecution();
108 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
109 // Note that we lose the actual launch status.
110 (void)mPreparedModelV1_2->execute_1_2(request, measure, callback);
111 } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
112 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
113 callback->notify_1_2(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
114 } else {
115 callback->notify_1_2(convertToV1_0(mErrorStatus), {}, kBadTiming);
116 }
117 }).detach();
118 return V1_0::ErrorStatus::NONE;
119 }
120
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)121 hardware::Return<V1_3::ErrorStatus> execute_1_3(
122 const V1_3::Request& request, V1_2::MeasureTiming measure,
123 const V1_3::OptionalTimePoint& deadline,
124 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
125 const sp<V1_3::IExecutionCallback>& callback) override {
126 CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
127 std::thread([this, request, measure, deadline, loopTimeoutDuration, callback] {
128 dummyExecution();
129 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
130 // Note that we lose the actual launch status.
131 (void)mPreparedModelV1_3->execute_1_3(request, measure, deadline,
132 loopTimeoutDuration, callback);
133 } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
134 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
135 callback->notify_1_3(mErrorStatus, {shape}, kBadTiming);
136 } else {
137 callback->notify_1_3(mErrorStatus, {}, kBadTiming);
138 }
139 }).detach();
140 return V1_3::ErrorStatus::NONE;
141 }
142
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)143 hardware::Return<void> executeSynchronously(const V1_0::Request& request,
144 V1_2::MeasureTiming measure,
145 executeSynchronously_cb cb) override {
146 CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
147 dummyExecution();
148 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
149 return mPreparedModelV1_2->executeSynchronously(request, measure, cb);
150 } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
151 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
152 cb(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
153 return hardware::Void();
154 } else {
155 cb(convertToV1_0(mErrorStatus), {}, kBadTiming);
156 return hardware::Void();
157 }
158 }
159
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)160 hardware::Return<void> executeSynchronously_1_3(
161 const V1_3::Request& request, V1_2::MeasureTiming measure,
162 const V1_3::OptionalTimePoint& deadline,
163 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
164 executeSynchronously_1_3_cb cb) override {
165 CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
166 dummyExecution();
167 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
168 return mPreparedModelV1_3->executeSynchronously_1_3(request, measure, deadline,
169 loopTimeoutDuration, cb);
170 } else if (mErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
171 V1_2::OutputShape shape = {.dimensions = {1}, .isSufficient = false};
172 cb(mErrorStatus, {shape}, kBadTiming);
173 return hardware::Void();
174 } else {
175 cb(mErrorStatus, {}, kBadTiming);
176 return hardware::Void();
177 }
178 }
179
180 // ExecutionBurstServer::create has an overload that will use
181 // IPreparedModel::executeSynchronously(), so we can rely on that, rather
182 // than having to implement ExecutionBurstServer::IExecutorWithCache.
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)183 hardware::Return<void> configureExecutionBurst(
184 const sp<V1_2::IBurstCallback>& callback,
185 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
186 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
187 configureExecutionBurst_cb cb) override {
188 CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
189 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
190 const sp<V1_2::IBurstContext> burst =
191 nn::ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
192
193 cb(burst == nullptr ? V1_0::ErrorStatus::GENERAL_FAILURE : V1_0::ErrorStatus::NONE,
194 burst);
195 return hardware::Void();
196 } else {
197 cb(convertToV1_0(mErrorStatus), nullptr);
198 return hardware::Void();
199 }
200 }
201
202 // Note, due to the limitation of SampleDriver implementation, the call is
203 // synchronous. The test code that exercises this implementation of
204 // SampleDriver is written with that in mind. Therefore, this
205 // implementation is synchronous also. If the SampleDriver is updated to
206 // return real sync fence, this must be updated.
executeFenced(const V1_3::Request & request,const hardware::hidl_vec<hardware::hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb cb)207 hardware::Return<void> executeFenced(const V1_3::Request& request,
208 const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
209 V1_2::MeasureTiming measure,
210 const V1_3::OptionalTimePoint& deadline,
211 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
212 const V1_3::OptionalTimeoutDuration& duration,
213 executeFenced_cb cb) override {
214 CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
215 CHECK(mErrorStatus != V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE)
216 << "executeFenced does not support dynamic output shape";
217 dummyExecution();
218 if (mErrorStatus == V1_3::ErrorStatus::NONE) {
219 return mPreparedModelV1_3->executeFenced(request, waitFor, measure, deadline,
220 loopTimeoutDuration, duration, cb);
221 } else {
222 // Due to the limitations of the SampleDriver, all failures look
223 // like launch failures. If the SampleDriver is updated to return
224 // real sync fences, this must be updated.
225 cb(mErrorStatus, hardware::hidl_handle(nullptr), nullptr);
226 }
227 return hardware::Void();
228 }
229
230 // We can place the TestPreparedModelLatest system in a "pause" mode where
231 // no execution will complete until the system is taken out of that mode.
232 // Initially, the system is not in that mode.
pauseExecutions(bool v)233 static void pauseExecutions(bool v) { mPauseExecutions.store(v); }
234
235 // This function is only guaranteed to work in the following pattern:
236 // Consider thread A as primary thread
237 // - thread A: pauseExecutions(true);
238 // - thread A: launch execution (as thread B)
239 // - thread A: waitForExecutionToBegin(), block until call to dummyExecution by
240 // thread B makes mExecutionsInFlight nonzero
241 // - thread B: dummyExecution(), which makes mExecutionsInFlight nonzero and blocks
242 // until thread A calls pauseExecutions(false)
243 // - thread A: waitForExecutionToBegin() returns
244 // - thread A: pauseExecutions(false), allowing dummyExecution() on thread B to continue
245 // - thread B: dummyExecution() zeroes mExecutionsInFlight and returns
246 // - thread B: thread exits
waitForExecutionToBegin()247 static void waitForExecutionToBegin() {
248 CHECK(mPauseExecutions.load());
249 while (mExecutionsInFlight.load() == 0) {
250 }
251 }
252
253 private:
254 const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
255 const sp<V1_2::IPreparedModel> mPreparedModelV1_2;
256 const sp<V1_3::IPreparedModel> mPreparedModelV1_3;
257 V1_3::ErrorStatus mErrorStatus;
258
259 static std::atomic<bool> mPauseExecutions;
260 static std::atomic<unsigned int> mExecutionsInFlight;
261
dummyExecution()262 static void dummyExecution() {
263 CHECK_EQ(mExecutionsInFlight.fetch_add(1), 0u) << "We do not support concurrent executions";
264 while (mPauseExecutions.load()) {
265 }
266 mExecutionsInFlight.fetch_sub(1);
267 }
268 };
269 std::atomic<bool> TestPreparedModelLatest::mPauseExecutions = false;
270 std::atomic<unsigned int> TestPreparedModelLatest::mExecutionsInFlight = 0;
271
272 using TestPreparedModel13 = TestPreparedModelLatest;
273
274 // Like TestPreparedModelLatest, but implementing 1.2
275 class TestPreparedModel12 : public V1_2::IPreparedModel {
276 public:
TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)277 TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
278 : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
279
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)280 hardware::Return<V1_0::ErrorStatus> execute(
281 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
282 return mLatestPreparedModel->execute(request, callback);
283 }
284
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)285 hardware::Return<V1_0::ErrorStatus> execute_1_2(
286 const V1_0::Request& request, V1_2::MeasureTiming measure,
287 const sp<V1_2::IExecutionCallback>& callback) override {
288 return mLatestPreparedModel->execute_1_2(request, measure, callback);
289 }
290
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)291 hardware::Return<void> executeSynchronously(const V1_0::Request& request,
292 V1_2::MeasureTiming measure,
293 executeSynchronously_cb cb) override {
294 return mLatestPreparedModel->executeSynchronously(request, measure, cb);
295 }
296
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)297 hardware::Return<void> configureExecutionBurst(
298 const sp<V1_2::IBurstCallback>& callback,
299 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
300 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
301 configureExecutionBurst_cb cb) override {
302 return mLatestPreparedModel->configureExecutionBurst(callback, requestChannel,
303 resultChannel, cb);
304 }
305
306 private:
307 const sp<V1_3::IPreparedModel> mLatestPreparedModel;
308 };
309
310 // Like TestPreparedModelLatest, but implementing 1.0
311 class TestPreparedModel10 : public V1_0::IPreparedModel {
312 public:
TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel,V1_3::ErrorStatus errorStatus)313 TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, V1_3::ErrorStatus errorStatus)
314 : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}
315
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)316 hardware::Return<V1_0::ErrorStatus> execute(
317 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override {
318 return mLatestPreparedModel->execute(request, callback);
319 }
320
321 private:
322 const sp<V1_3::IPreparedModel> mLatestPreparedModel;
323 };
324
325 // Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
326 class TestDriver13 : public SampleDriver {
327 public:
328 // Allow dummying up the error status for execution of all models
329 // prepared from this driver. If errorStatus is NONE, then
330 // execute behaves normally (and sends back the actual execution
331 // status). Otherwise, don't bother to execute, and just send
332 // back errorStatus (as the execution status, not the launch
333 // status).
TestDriver13(const std::string & name,V1_3::ErrorStatus errorStatus)334 TestDriver13(const std::string& name, V1_3::ErrorStatus errorStatus)
335 : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}
336
getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb)337 hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb) override {
338 android::nn::initVLogMask();
339 const V1_0::PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
340 V1_3::Capabilities capabilities = {
341 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
342 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
343 .operandPerformance =
344 nn::nonExtensionOperandPerformance<nn::HalVersion::V1_3>(kPerf),
345 .ifPerformance = kPerf,
346 .whilePerformance = kPerf};
347 _hidl_cb(V1_3::ErrorStatus::NONE, capabilities);
348 return hardware::Void();
349 }
350
getSupportedOperations_1_3(const HidlModel & model,getSupportedOperations_1_3_cb cb)351 hardware::Return<void> getSupportedOperations_1_3(const HidlModel& model,
352 getSupportedOperations_1_3_cb cb) override {
353 if (nn::validateModel(model)) {
354 std::vector<bool> supported(model.main.operations.size(), true);
355 cb(V1_3::ErrorStatus::NONE, supported);
356 } else {
357 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
358 }
359 return hardware::Void();
360 }
361
prepareModel_1_3(const HidlModel & model,V1_1::ExecutionPreference preference,V1_3::Priority priority,const V1_3::OptionalTimePoint & deadline,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_3::IPreparedModelCallback> & actualCallback)362 hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
363 const HidlModel& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
364 const V1_3::OptionalTimePoint& deadline,
365 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
366 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
367 const nn::HalCacheToken& token,
368 const sp<V1_3::IPreparedModelCallback>& actualCallback) override {
369 sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
370 hardware::Return<V1_3::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_3(
371 model, preference, priority, deadline, modelCache, dataCache, token, localCallback);
372 if (!prepareModelReturn.isOkUnchecked()) {
373 return prepareModelReturn;
374 }
375 if (prepareModelReturn != V1_3::ErrorStatus::NONE) {
376 actualCallback->notify_1_3(
377 convertToV1_3(localCallback->getStatus()),
378 V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
379 return prepareModelReturn;
380 }
381 localCallback->wait();
382 if (localCallback->getStatus() != ErrorStatus::NONE) {
383 actualCallback->notify_1_3(
384 convertToV1_3(localCallback->getStatus()),
385 V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
386 } else {
387 actualCallback->notify_1_3(
388 V1_3::ErrorStatus::NONE,
389 new TestPreparedModel13(localCallback->getPreparedModel(), mErrorStatus));
390 }
391 return prepareModelReturn;
392 }
393
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)394 hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
395 const V1_2::Model& model, V1_1::ExecutionPreference preference,
396 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
397 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
398 const nn::HalCacheToken& token,
399 const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
400 sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
401 hardware::Return<V1_0::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
402 model, preference, modelCache, dataCache, token, localCallback);
403 if (!prepareModelReturn.isOkUnchecked()) {
404 return prepareModelReturn;
405 }
406 if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
407 actualCallback->notify_1_2(
408 convertToV1_0(localCallback->getStatus()),
409 V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
410 return prepareModelReturn;
411 }
412 localCallback->wait();
413 if (localCallback->getStatus() != ErrorStatus::NONE) {
414 actualCallback->notify_1_2(
415 convertToV1_0(localCallback->getStatus()),
416 V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
417 } else {
418 actualCallback->notify_1_2(
419 V1_0::ErrorStatus::NONE,
420 new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
421 }
422 return prepareModelReturn;
423 }
424
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)425 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
426 const V1_1::Model& model, V1_1::ExecutionPreference preference,
427 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
428 sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
429 hardware::Return<V1_0::ErrorStatus> prepareModelReturn =
430 SampleDriver::prepareModel_1_1(model, preference, localCallback);
431 if (!prepareModelReturn.isOkUnchecked()) {
432 return prepareModelReturn;
433 }
434 if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
435 actualCallback->notify(convertToV1_0(localCallback->getStatus()),
436 localCallback->getPreparedModel());
437 return prepareModelReturn;
438 }
439 localCallback->wait();
440 if (localCallback->getStatus() != ErrorStatus::NONE) {
441 actualCallback->notify(convertToV1_0(localCallback->getStatus()),
442 localCallback->getPreparedModel());
443 } else {
444 actualCallback->notify(
445 V1_0::ErrorStatus::NONE,
446 new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
447 }
448 return prepareModelReturn;
449 }
450
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)451 hardware::Return<V1_0::ErrorStatus> prepareModel(
452 const V1_0::Model& model,
453 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
454 return prepareModel_1_1(nn::convertToV1_1(model),
455 V1_1::ExecutionPreference::FAST_SINGLE_ANSWER, actualCallback);
456 }
457
458 private:
459 V1_3::ErrorStatus mErrorStatus;
460 };
461
462 // Like TestDriver, but implementing 1.2
463 class TestDriver12 : public V1_2::IDevice {
464 public:
TestDriver12(const std::string & name,V1_3::ErrorStatus errorStatus)465 TestDriver12(const std::string& name, V1_3::ErrorStatus errorStatus)
466 : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)467 hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
468 return mLatestDriver->getCapabilities_1_2(_hidl_cb);
469 }
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)470 hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
471 return mLatestDriver->getCapabilities_1_1(_hidl_cb);
472 }
getCapabilities(getCapabilities_cb _hidl_cb)473 hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
474 return mLatestDriver->getCapabilities(_hidl_cb);
475 }
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb _hidl_cb)476 hardware::Return<void> getSupportedOperations_1_2(
477 const V1_2::Model& model, getSupportedOperations_1_2_cb _hidl_cb) override {
478 return mLatestDriver->getSupportedOperations_1_2(model, _hidl_cb);
479 }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)480 hardware::Return<void> getSupportedOperations_1_1(
481 const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
482 return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
483 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)484 hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
485 getSupportedOperations_cb _hidl_cb) override {
486 return mLatestDriver->getSupportedOperations(model, _hidl_cb);
487 }
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & actualCallback)488 hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
489 const V1_2::Model& model, V1_1::ExecutionPreference preference,
490 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
491 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
492 const nn::HalCacheToken& token,
493 const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
494 return mLatestDriver->prepareModel_1_2(model, preference, modelCache, dataCache, token,
495 actualCallback);
496 }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)497 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
498 const V1_1::Model& model, V1_1::ExecutionPreference preference,
499 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
500 return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
501 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)502 hardware::Return<V1_0::ErrorStatus> prepareModel(
503 const V1_0::Model& model,
504 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
505 return mLatestDriver->prepareModel(model, actualCallback);
506 }
getStatus()507 hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getVersionString(getVersionString_cb _hidl_cb)508 hardware::Return<void> getVersionString(getVersionString_cb _hidl_cb) override {
509 return mLatestDriver->getVersionString(_hidl_cb);
510 }
getType(getType_cb _hidl_cb)511 hardware::Return<void> getType(getType_cb _hidl_cb) override {
512 return mLatestDriver->getType(_hidl_cb);
513 }
getSupportedExtensions(getSupportedExtensions_cb _hidl_cb)514 hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb _hidl_cb) {
515 return mLatestDriver->getSupportedExtensions(_hidl_cb);
516 }
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb)517 hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb) {
518 return mLatestDriver->getNumberOfCacheFilesNeeded(_hidl_cb);
519 }
prepareModelFromCache(const hardware::hidl_vec<hardware::hidl_handle> & modelCache,const hardware::hidl_vec<hardware::hidl_handle> & dataCache,const nn::HalCacheToken & token,const sp<V1_2::IPreparedModelCallback> & callback)520 hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
521 const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
522 const hardware::hidl_vec<hardware::hidl_handle>& dataCache,
523 const nn::HalCacheToken& token, const sp<V1_2::IPreparedModelCallback>& callback) {
524 return mLatestDriver->prepareModelFromCache(modelCache, dataCache, token, callback);
525 }
526
527 private:
528 const sp<V1_3::IDevice> mLatestDriver;
529 };
530
531 // Like TestDriver, but implementing 1.1
532 class TestDriver11 : public V1_1::IDevice {
533 public:
TestDriver11(const std::string & name,V1_3::ErrorStatus errorStatus)534 TestDriver11(const std::string& name, V1_3::ErrorStatus errorStatus)
535 : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)536 hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
537 return mLatestDriver->getCapabilities_1_1(_hidl_cb);
538 }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)539 hardware::Return<void> getSupportedOperations_1_1(
540 const V1_1::Model& model, getSupportedOperations_1_1_cb _hidl_cb) override {
541 return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
542 }
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)543 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
544 const V1_1::Model& model, V1_1::ExecutionPreference preference,
545 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
546 return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
547 }
getStatus()548 hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
getCapabilities(getCapabilities_cb _hidl_cb)549 hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
550 return mLatestDriver->getCapabilities(_hidl_cb);
551 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)552 hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
553 getSupportedOperations_cb _hidl_cb) override {
554 return mLatestDriver->getSupportedOperations(model, _hidl_cb);
555 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)556 hardware::Return<V1_0::ErrorStatus> prepareModel(
557 const V1_0::Model& model,
558 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
559 return mLatestDriver->prepareModel(model, actualCallback);
560 }
561
562 private:
563 const sp<V1_3::IDevice> mLatestDriver;
564 };
565
566 // Like TestDriver, but implementing 1.0
567 class TestDriver10 : public V1_0::IDevice {
568 public:
TestDriver10(const std::string & name,V1_3::ErrorStatus errorStatus)569 TestDriver10(const std::string& name, V1_3::ErrorStatus errorStatus)
570 : mLatestDriver(new TestDriver13(name, errorStatus)) {}
getCapabilities(getCapabilities_cb _hidl_cb)571 hardware::Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
572 return mLatestDriver->getCapabilities(_hidl_cb);
573 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)574 hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
575 getSupportedOperations_cb _hidl_cb) override {
576 return mLatestDriver->getSupportedOperations(model, _hidl_cb);
577 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)578 hardware::Return<V1_0::ErrorStatus> prepareModel(
579 const V1_0::Model& model,
580 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
581 return mLatestDriver->prepareModel(model, actualCallback);
582 }
getStatus()583 hardware::Return<V1_0::DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
584
585 private:
586 const sp<V1_3::IDevice> mLatestDriver;
587 };
588
589 // This class adds some simple utilities on top of WrapperCompilation in order
590 // to provide access to certain features from CompilationBuilder that are not
591 // exposed by the base class.
592 template <typename DriverClass>
593 class TestCompilation : public WrapperCompilation {
594 public:
595 // Allow dummying up the error status for all executions from this
596 // compilation. If errorStatus is NONE, then execute behaves
597 // normally (and sends back the actual execution status).
598 // Otherwise, don't bother to execute, and just send back
599 // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,V1_3::ErrorStatus errorStatus)600 TestCompilation(const WrapperModel* model, const std::string& deviceName,
601 V1_3::ErrorStatus errorStatus) {
602 std::vector<std::shared_ptr<Device>> devices;
603 auto device = DeviceManager::forTest_makeDriverDevice(
604 nn::makeSharedDevice(deviceName, new DriverClass(deviceName, errorStatus)));
605 devices.push_back(device);
606
607 nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
608 CompilationBuilder* c = nullptr;
609 int result = m->createCompilation(&c, devices);
610 EXPECT_EQ(result, 0);
611 // We need to ensure that we use our TestDriver and do not
612 // fall back to CPU. (If we allow CPU fallback, then when our
613 // TestDriver reports an execution failure, we'll re-execute
614 // on CPU, and will not see the failure.)
615 c->forTest_setPartitioning(DeviceManager::kPartitioningWithoutFallback);
616 mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
617 }
618 };
619
620 // This class has roughly the same functionality as TestCompilation class.
621 // The major difference is that Introspection API is used to select the device.
622 class TestIntrospectionCompilation : public WrapperCompilation {
623 public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)624 TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
625 std::vector<ANeuralNetworksDevice*> mDevices;
626 uint32_t numDevices = 0;
627 EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
628 EXPECT_GE(numDevices, (uint32_t)1);
629
630 for (uint32_t i = 0; i < numDevices; i++) {
631 ANeuralNetworksDevice* device = nullptr;
632 EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
633 const char* buffer = nullptr;
634 int result = ANeuralNetworksDevice_getName(device, &buffer);
635 if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
636 mDevices.push_back(device);
637 }
638 }
639 // In CPU only mode, DeviceManager::getDrivers() will not be able to
640 // provide the actual device list. We will not be able to find the test
641 // driver with specified deviceName.
642 if (!DeviceManager::get()->getUseCpuOnly()) {
643 EXPECT_EQ(mDevices.size(), (uint32_t)1);
644
645 int result = ANeuralNetworksCompilation_createForDevices(
646 model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
647 EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
648 }
649 }
650 };
651
652 template <class DriverClass>
653 class ExecutionTestTemplate
654 : public ::testing::TestWithParam<std::tuple<V1_3::ErrorStatus, WrapperResult, bool>> {
655 public:
ExecutionTestTemplate()656 ExecutionTestTemplate()
657 : kName(toString(std::get<0>(GetParam()))),
658 kForceErrorStatus(std::get<0>(GetParam())),
659 kExpectResult(std::get<1>(GetParam())),
660 kUseIntrospectionAPI(std::get<2>(GetParam())),
661 mModel(makeModel()) {
662 if (kUseIntrospectionAPI) {
663 DeviceManager::get()->forTest_registerDevice(
664 nn::makeSharedDevice(kName, new DriverClass(kName.c_str(), kForceErrorStatus)));
665 mCompilation = TestIntrospectionCompilation(&mModel, kName);
666 } else {
667 mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
668 }
669 }
670
671 protected:
672 // Unit test method
673 // Set "reusable" to true to test reusable execution; Otherwise, test non-reusable execution.
674 void TestWait(bool reusable);
675
TearDown()676 virtual void TearDown() {
677 // Reinitialize the device list since Introspection API path altered it.
678 if (kUseIntrospectionAPI) {
679 DeviceManager::get()->forTest_reInitializeDeviceList();
680 }
681 }
682
getDimensionsWhileRunning(WrapperExecution & execution)683 void getDimensionsWhileRunning(WrapperExecution& execution) {
684 TestPreparedModelLatest::waitForExecutionToBegin();
685 // Cannot query dimensions while execution is running
686 std::vector<uint32_t> dimensions;
687 EXPECT_EQ(execution.getOutputOperandDimensions(0, &dimensions), WrapperResult::BAD_STATE);
688 }
689
690 const std::string kName;
691
692 // Allow dummying up the error status for execution. If
693 // kForceErrorStatus is NONE, then execution behaves normally (and
694 // sends back the actual execution status). Otherwise, don't
695 // bother to execute, and just send back kForceErrorStatus (as the
696 // execution status, not the launch status).
697 const V1_3::ErrorStatus kForceErrorStatus;
698
699 // What result do we expect from the execution? (The WrapperResult
700 // equivalent of kForceErrorStatus.)
701 const WrapperResult kExpectResult;
702
703 // Whether mCompilation is created via Introspection API or not.
704 const bool kUseIntrospectionAPI;
705
706 WrapperModel mModel;
707 WrapperCompilation mCompilation;
708
setInputOutput(WrapperExecution * execution)709 void setInputOutput(WrapperExecution* execution) {
710 mInputBuffer = kInputBuffer;
711 mOutputBuffer = kOutputBufferInitial;
712 ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)),
713 WrapperResult::NO_ERROR);
714 ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)),
715 WrapperResult::NO_ERROR);
716 }
717
718 const float kInputBuffer = 3.14;
719 const float kOutputBufferInitial = 0;
720 float mInputBuffer;
721 float mOutputBuffer;
722 const float kOutputBufferExpected = 3;
723 const std::vector<uint32_t> kOutputDimensionsExpected = {1};
724
725 private:
makeModel()726 static WrapperModel makeModel() {
727 static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});
728
729 WrapperModel model;
730 uint32_t input = model.addOperand(&tensorType);
731 uint32_t output = model.addOperand(&tensorType);
732 model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
733 model.identifyInputsAndOutputs({input}, {output});
734 assert(model.finish() == WrapperResult::NO_ERROR);
735
736 return model;
737 }
738 };
739
computeHelper(bool reusable,const std::function<void ()> & compute)740 void computeHelper(bool reusable, const std::function<void()>& compute) {
741 {
742 SCOPED_TRACE(reusable ? "first time reusable" : "non-reusable");
743 compute();
744 }
745 if (reusable) {
746 SCOPED_TRACE("second time reusable");
747 compute();
748 }
749 }
750
751 template <class DriverClass>
TestWait(bool reusable)752 void ExecutionTestTemplate<DriverClass>::TestWait(bool reusable) {
753 SCOPED_TRACE(kName);
754 // Skip Introspection API tests when CPU only flag is forced on.
755 if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
756 GTEST_SKIP();
757 }
758
759 ASSERT_EQ(mCompilation.finish(), WrapperResult::NO_ERROR);
760
761 {
762 SCOPED_TRACE("startCompute");
763 WrapperExecution execution(&mCompilation);
764 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
765 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
766 const auto compute = [this, &execution] {
767 TestPreparedModelLatest::pauseExecutions(true);
768 WrapperEvent event;
769 ASSERT_EQ(execution.startCompute(&event), WrapperResult::NO_ERROR);
770 getDimensionsWhileRunning(execution);
771 TestPreparedModelLatest::pauseExecutions(false);
772 ASSERT_EQ(event.wait(), kExpectResult);
773 if (kExpectResult == WrapperResult::NO_ERROR) {
774 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
775 }
776 std::vector<uint32_t> dimensions;
777 if (kExpectResult == WrapperResult::NO_ERROR ||
778 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
779 // Only one output operand, hardcoded as index 0.
780 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
781 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
782 } else {
783 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
784 WrapperResult::BAD_STATE);
785 }
786 };
787 computeHelper(reusable, compute);
788 }
789 {
790 SCOPED_TRACE("compute");
791 WrapperExecution execution(&mCompilation);
792 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
793 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
794 const auto compute = [this, &execution] {
795 TestPreparedModelLatest::pauseExecutions(true);
796 std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
797 getDimensionsWhileRunning(execution);
798 TestPreparedModelLatest::pauseExecutions(false);
799 run.join();
800 if (kExpectResult == WrapperResult::NO_ERROR) {
801 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
802 }
803 std::vector<uint32_t> dimensions;
804 if (kExpectResult == WrapperResult::NO_ERROR ||
805 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
806 // Only one output operand, hardcoded as index 0.
807 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
808 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
809 } else {
810 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
811 WrapperResult::BAD_STATE);
812 }
813 };
814 computeHelper(reusable, compute);
815 }
816 {
817 SCOPED_TRACE("burstCompute");
818
819 // TODO: If a burst API is added to nn::test_wrapper (e.g.,
820 // Execution::burstCompute()), then use that, rather than
821 // Execution::compute(WrapperExecution::ComputeMode::BURST).
822
823 WrapperExecution execution(&mCompilation);
824 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
825 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
826 const auto compute = [this, &execution] {
827 TestPreparedModelLatest::pauseExecutions(true);
828 std::thread run([this, &execution] {
829 EXPECT_EQ(execution.compute(WrapperExecution::ComputeMode::BURST), kExpectResult);
830 });
831 getDimensionsWhileRunning(execution);
832 TestPreparedModelLatest::pauseExecutions(false);
833 run.join();
834 if (kExpectResult == WrapperResult::NO_ERROR) {
835 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
836 }
837 std::vector<uint32_t> dimensions;
838 if (kExpectResult == WrapperResult::NO_ERROR ||
839 kExpectResult == WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
840 // Only one output operand, hardcoded as index 0.
841 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
842 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
843 } else {
844 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
845 WrapperResult::BAD_STATE);
846 }
847 };
848 computeHelper(reusable, compute);
849 }
850 if (kExpectResult != WrapperResult::OUTPUT_INSUFFICIENT_SIZE) {
851 // computeWithDependencies doesn't support OUTPUT_INSUFFICIENT_SIZE
852 SCOPED_TRACE("computeWithDependencies");
853 WrapperExecution execution(&mCompilation);
854 ASSERT_EQ(execution.setReusable(reusable), WrapperResult::NO_ERROR);
855 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
856
857 const auto compute = [this, &execution] {
858 TestPreparedModelLatest::pauseExecutions(true);
859
860 WrapperEvent event;
861 // Note, due to the limitation of SampleDriver implementation, the call is synchronous.
862 // If the SampleDriver is updated to return real sync fence, this must be updated.
863 std::thread run([this, &execution, &event] {
864 EXPECT_EQ(execution.startComputeWithDependencies({}, 0, &event), kExpectResult);
865 });
866 getDimensionsWhileRunning(execution);
867 TestPreparedModelLatest::pauseExecutions(false);
868 run.join();
869 if (kExpectResult == WrapperResult::NO_ERROR) {
870 ASSERT_EQ(event.wait(), kExpectResult);
871 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
872 } else {
873 ASSERT_EQ(event.wait(), WrapperResult::UNEXPECTED_NULL);
874 }
875 std::vector<uint32_t> dimensions;
876 if (kExpectResult == WrapperResult::NO_ERROR) {
877 // Only one output operand, hardcoded as index 0.
878 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
879 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
880 } else {
881 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
882 WrapperResult::BAD_STATE);
883 }
884 };
885 computeHelper(reusable, compute);
886 }
887 }
888
889 auto kTestValues = ::testing::Values(
890 std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
891 /* kUseIntrospectionAPI */ false),
892 std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
893 /* kUseIntrospectionAPI */ false),
894 std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
895 /* kUseIntrospectionAPI */ false),
896 std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
897 WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
898 /* kUseIntrospectionAPI */ false),
899 std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
900 /* kUseIntrospectionAPI */ false));
901
902 class ExecutionTest13 : public ExecutionTestTemplate<TestDriver13> {};
TEST_P(ExecutionTest13,Wait)903 TEST_P(ExecutionTest13, Wait) {
904 TestWait(/*reusable=*/false);
905 }
TEST_P(ExecutionTest13,WaitReusable)906 TEST_P(ExecutionTest13, WaitReusable) {
907 TestWait(/*reusable=*/true);
908 }
909 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest13, kTestValues);
910
911 class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12,Wait)912 TEST_P(ExecutionTest12, Wait) {
913 TestWait(/*reusable=*/false);
914 }
TEST_P(ExecutionTest12,WaitReusable)915 TEST_P(ExecutionTest12, WaitReusable) {
916 TestWait(/*reusable=*/true);
917 }
918 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest12, kTestValues);
919
920 class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11,Wait)921 TEST_P(ExecutionTest11, Wait) {
922 if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
923 TestWait(/*reusable=*/false);
924 }
TEST_P(ExecutionTest11,WaitReusable)925 TEST_P(ExecutionTest11, WaitReusable) {
926 if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
927 TestWait(/*reusable=*/true);
928 }
929 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest11, kTestValues);
930
931 class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10,Wait)932 TEST_P(ExecutionTest10, Wait) {
933 if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
934 TestWait(/*reusable=*/false);
935 }
TEST_P(ExecutionTest10,WaitReusable)936 TEST_P(ExecutionTest10, WaitReusable) {
937 if (kForceErrorStatus == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
938 TestWait(/*reusable=*/true);
939 }
940 INSTANTIATE_TEST_SUITE_P(Flavor, ExecutionTest10, kTestValues);
941
942 auto kIntrospectionTestValues = ::testing::Values(
943 std::make_tuple(V1_3::ErrorStatus::NONE, WrapperResult::NO_ERROR,
944 /* kUseIntrospectionAPI */ true),
945 std::make_tuple(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, WrapperResult::UNAVAILABLE_DEVICE,
946 /* kUseIntrospectionAPI */ true),
947 std::make_tuple(V1_3::ErrorStatus::GENERAL_FAILURE, WrapperResult::OP_FAILED,
948 /* kUseIntrospectionAPI */ true),
949 std::make_tuple(V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
950 WrapperResult::OUTPUT_INSUFFICIENT_SIZE,
951 /* kUseIntrospectionAPI */ true),
952 std::make_tuple(V1_3::ErrorStatus::INVALID_ARGUMENT, WrapperResult::BAD_DATA,
953 /* kUseIntrospectionAPI */ true));
954
955 INSTANTIATE_TEST_SUITE_P(IntrospectionFlavor, ExecutionTest13, kIntrospectionTestValues);
956
957 } // namespace
958 } // namespace android
959