1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #undef NDEBUG
18
19 #include "Callbacks.h"
20 #include "CompilationBuilder.h"
21 #include "Manager.h"
22 #include "ModelBuilder.h"
23 #include "NeuralNetworks.h"
24 #include "SampleDriver.h"
25 #include "TestNeuralNetworksWrapper.h"
26 #include "ValidateHal.h"
27
28 #include <algorithm>
29 #include <cassert>
30 #include <vector>
31
32 #include <gtest/gtest.h>
33
34 namespace android {
35
36 using CompilationBuilder = nn::CompilationBuilder;
37 using Device = nn::Device;
38 using DeviceManager = nn::DeviceManager;
39 using HidlModel = hardware::neuralnetworks::V1_2::Model;
40 using HidlToken = hardware::hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
41 using PreparedModelCallback = hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
42 using Result = nn::test_wrapper::Result;
43 using SampleDriver = nn::sample_driver::SampleDriver;
44 using WrapperCompilation = nn::test_wrapper::Compilation;
45 using WrapperEvent = nn::test_wrapper::Event;
46 using WrapperExecution = nn::test_wrapper::Execution;
47 using WrapperModel = nn::test_wrapper::Model;
48 using WrapperOperandType = nn::test_wrapper::OperandType;
49 using WrapperType = nn::test_wrapper::Type;
50
51 template <typename T>
52 using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
53
54 namespace {
55
56 const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
57
58 // Wraps an V1_2::IPreparedModel to allow dummying up the execution status.
59 class TestPreparedModel12 : public V1_2::IPreparedModel {
60 public:
61 // If errorStatus is NONE, then execute behaves normally (and sends back
62 // the actual execution status). Otherwise, don't bother to execute, and
63 // just send back errorStatus (as the execution status, not the launch
64 // status).
TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel,ErrorStatus errorStatus)65 TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
66 : mPreparedModelV1_0(preparedModel),
67 mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
68 mErrorStatus(errorStatus) {}
69
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)70 Return<ErrorStatus> execute(const Request& request,
71 const sp<V1_0::IExecutionCallback>& callback) override {
72 CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
73 if (mErrorStatus == ErrorStatus::NONE) {
74 return mPreparedModelV1_0->execute(request, callback);
75 } else {
76 callback->notify(mErrorStatus);
77 return ErrorStatus::NONE;
78 }
79 }
80
execute_1_2(const Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)81 Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
82 const sp<V1_2::IExecutionCallback>& callback) override {
83 CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
84 if (mErrorStatus == ErrorStatus::NONE) {
85 return mPreparedModelV1_2->execute_1_2(request, measure, callback);
86 } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
87 OutputShape shape = {.dimensions = {1}, .isSufficient = false};
88 callback->notify_1_2(mErrorStatus, {shape}, kBadTiming);
89 return ErrorStatus::NONE;
90 } else {
91 callback->notify_1_2(mErrorStatus, {}, kBadTiming);
92 return ErrorStatus::NONE;
93 }
94 }
95
executeSynchronously(const Request & request,MeasureTiming measure,executeSynchronously_cb cb)96 Return<void> executeSynchronously(const Request& request, MeasureTiming measure,
97 executeSynchronously_cb cb) override {
98 CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
99 if (mErrorStatus == ErrorStatus::NONE) {
100 return mPreparedModelV1_2->executeSynchronously(
101 request, measure,
102 [&cb](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
103 const Timing& timing) { cb(error, outputShapes, timing); });
104 } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
105 OutputShape shape = {.dimensions = {1}, .isSufficient = false};
106 cb(mErrorStatus, {shape}, kBadTiming);
107 return Void();
108 } else {
109 cb(mErrorStatus, {}, kBadTiming);
110 return Void();
111 }
112 }
113
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)114 Return<void> configureExecutionBurst(
115 const sp<V1_2::IBurstCallback>& callback,
116 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
117 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
118 configureExecutionBurst_cb cb) override {
119 if (mErrorStatus == ErrorStatus::NONE) {
120 return mPreparedModelV1_2->configureExecutionBurst(callback, requestChannel,
121 resultChannel, cb);
122 } else {
123 cb(mErrorStatus, nullptr);
124 return Void();
125 }
126 }
127
128 private:
129 const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
130 const sp<V1_2::IPreparedModel> mPreparedModelV1_2;
131 ErrorStatus mErrorStatus;
132 };
133
134 // Like TestPreparedModel12, but implementing 1.0
135 class TestPreparedModel10 : public V1_0::IPreparedModel {
136 public:
TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel,ErrorStatus errorStatus)137 TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
138 : m12PreparedModel(new TestPreparedModel12(preparedModel, errorStatus)) {}
139
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)140 Return<ErrorStatus> execute(const Request& request,
141 const sp<V1_0::IExecutionCallback>& callback) override {
142 return m12PreparedModel->execute(request, callback);
143 }
144
145 private:
146 const sp<V1_2::IPreparedModel> m12PreparedModel;
147 };
148
149 // Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
150 class TestDriver12 : public SampleDriver {
151 public:
152 // Allow dummying up the error status for execution of all models
153 // prepared from this driver. If errorStatus is NONE, then
154 // execute behaves normally (and sends back the actual execution
155 // status). Otherwise, don't bother to execute, and just send
156 // back errorStatus (as the execution status, not the launch
157 // status).
TestDriver12(const std::string & name,ErrorStatus errorStatus)158 TestDriver12(const std::string& name, ErrorStatus errorStatus)
159 : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}
160
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)161 Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
162 android::nn::initVLogMask();
163 const PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
164 Capabilities capabilities = {
165 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
166 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
167 .operandPerformance = nn::nonExtensionOperandPerformance(kPerf)};
168 _hidl_cb(ErrorStatus::NONE, capabilities);
169 return Void();
170 }
171
getSupportedOperations_1_2(const HidlModel & model,getSupportedOperations_1_2_cb cb)172 Return<void> getSupportedOperations_1_2(const HidlModel& model,
173 getSupportedOperations_1_2_cb cb) override {
174 if (nn::validateModel(model)) {
175 std::vector<bool> supported(model.operations.size(), true);
176 cb(ErrorStatus::NONE, supported);
177 } else {
178 std::vector<bool> supported;
179 cb(ErrorStatus::INVALID_ARGUMENT, supported);
180 }
181 return Void();
182 }
183
prepareModel_1_2(const HidlModel & model,ExecutionPreference preference,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,const HidlToken & token,const sp<IPreparedModelCallback> & actualCallback)184 Return<ErrorStatus> prepareModel_1_2(
185 const HidlModel& model, ExecutionPreference preference,
186 const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
187 const HidlToken& token, const sp<IPreparedModelCallback>& actualCallback) override {
188 sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
189 Return<ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
190 model, preference, modelCache, dataCache, token, localCallback);
191 if (!prepareModelReturn.isOkUnchecked()) {
192 return prepareModelReturn;
193 }
194 if (prepareModelReturn != ErrorStatus::NONE) {
195 actualCallback->notify_1_2(
196 localCallback->getStatus(),
197 V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
198 return prepareModelReturn;
199 }
200 localCallback->wait();
201 if (localCallback->getStatus() != ErrorStatus::NONE) {
202 actualCallback->notify_1_2(
203 localCallback->getStatus(),
204 V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
205 } else {
206 actualCallback->notify_1_2(
207 ErrorStatus::NONE,
208 new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
209 }
210 return prepareModelReturn;
211 }
212
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)213 Return<ErrorStatus> prepareModel_1_1(
214 const V1_1::Model& model, ExecutionPreference preference,
215 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
216 sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
217 Return<ErrorStatus> prepareModelReturn =
218 SampleDriver::prepareModel_1_1(model, preference, localCallback);
219 if (!prepareModelReturn.isOkUnchecked()) {
220 return prepareModelReturn;
221 }
222 if (prepareModelReturn != ErrorStatus::NONE) {
223 actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
224 return prepareModelReturn;
225 }
226 localCallback->wait();
227 if (localCallback->getStatus() != ErrorStatus::NONE) {
228 actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
229 } else {
230 actualCallback->notify(
231 ErrorStatus::NONE,
232 new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
233 }
234 return prepareModelReturn;
235 }
236
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)237 Return<ErrorStatus> prepareModel(
238 const V1_0::Model& model,
239 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
240 return prepareModel_1_1(nn::convertToV1_1(model), ExecutionPreference::FAST_SINGLE_ANSWER,
241 actualCallback);
242 }
243
244 private:
245 ErrorStatus mErrorStatus;
246 };
247
248 // Like TestDriver, but implementing 1.1
249 class TestDriver11 : public V1_1::IDevice {
250 public:
TestDriver11(const std::string & name,ErrorStatus errorStatus)251 TestDriver11(const std::string& name, ErrorStatus errorStatus)
252 : m12Driver(new TestDriver12(name, errorStatus)) {}
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)253 Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
254 return m12Driver->getCapabilities_1_1(_hidl_cb);
255 }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)256 Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
257 getSupportedOperations_1_1_cb _hidl_cb) override {
258 return m12Driver->getSupportedOperations_1_1(model, _hidl_cb);
259 }
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)260 Return<ErrorStatus> prepareModel_1_1(
261 const V1_1::Model& model, ExecutionPreference preference,
262 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
263 return m12Driver->prepareModel_1_1(model, preference, actualCallback);
264 }
getStatus()265 Return<DeviceStatus> getStatus() override { return m12Driver->getStatus(); }
getCapabilities(getCapabilities_cb _hidl_cb)266 Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
267 return m12Driver->getCapabilities(_hidl_cb);
268 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)269 Return<void> getSupportedOperations(const V1_0::Model& model,
270 getSupportedOperations_cb _hidl_cb) override {
271 return m12Driver->getSupportedOperations(model, _hidl_cb);
272 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)273 Return<ErrorStatus> prepareModel(
274 const V1_0::Model& model,
275 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
276 return m12Driver->prepareModel(model, actualCallback);
277 }
278
279 private:
280 const sp<V1_2::IDevice> m12Driver;
281 };
282
283 // Like TestDriver, but implementing 1.0
284 class TestDriver10 : public V1_0::IDevice {
285 public:
TestDriver10(const std::string & name,ErrorStatus errorStatus)286 TestDriver10(const std::string& name, ErrorStatus errorStatus)
287 : m12Driver(new TestDriver12(name, errorStatus)) {}
getCapabilities(getCapabilities_cb _hidl_cb)288 Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
289 return m12Driver->getCapabilities(_hidl_cb);
290 }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)291 Return<void> getSupportedOperations(const V1_0::Model& model,
292 getSupportedOperations_cb _hidl_cb) override {
293 return m12Driver->getSupportedOperations(model, _hidl_cb);
294 }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)295 Return<ErrorStatus> prepareModel(
296 const V1_0::Model& model,
297 const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
298 return m12Driver->prepareModel(model, actualCallback);
299 }
getStatus()300 Return<DeviceStatus> getStatus() override { return m12Driver->getStatus(); }
301
302 private:
303 const sp<V1_2::IDevice> m12Driver;
304 };
305
306 // This class adds some simple utilities on top of WrapperCompilation in order
307 // to provide access to certain features from CompilationBuilder that are not
308 // exposed by the base class.
309 template<typename DriverClass>
310 class TestCompilation : public WrapperCompilation {
311 public:
312 // Allow dummying up the error status for all executions from this
313 // compilation. If errorStatus is NONE, then execute behaves
314 // normally (and sends back the actual execution status).
315 // Otherwise, don't bother to execute, and just send back
316 // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,ErrorStatus errorStatus)317 TestCompilation(const WrapperModel* model, const std::string& deviceName,
318 ErrorStatus errorStatus) {
319 std::vector<std::shared_ptr<Device>> devices;
320 auto device = DeviceManager::forTest_makeDriverDevice(
321 deviceName, new DriverClass(deviceName, errorStatus));
322 devices.push_back(device);
323
324 nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
325 CompilationBuilder* c = nullptr;
326 int result = m->createCompilation(&c, devices);
327 EXPECT_EQ(result, 0);
328 // We need to ensure that we use our TestDriver and do not
329 // fall back to CPU. (If we allow CPU fallback, then when our
330 // TestDriver reports an execution failure, we'll re-execute
331 // on CPU, and will not see the failure.)
332 c->setPartitioning(DeviceManager::kPartitioningWithoutFallback);
333 mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
334 }
335 };
336
337 // This class has roughly the same functionality as TestCompilation class.
338 // The major difference is that Introspection API is used to select the device.
339 class TestIntrospectionCompilation : public WrapperCompilation {
340 public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)341 TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
342 std::vector<ANeuralNetworksDevice*> mDevices;
343 uint32_t numDevices = 0;
344 EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
345 EXPECT_GE(numDevices, (uint32_t)1);
346
347 for (uint32_t i = 0; i < numDevices; i++) {
348 ANeuralNetworksDevice* device = nullptr;
349 EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
350 const char* buffer = nullptr;
351 int result = ANeuralNetworksDevice_getName(device, &buffer);
352 if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
353 mDevices.push_back(device);
354 }
355 }
356 // In CPU only mode, DeviceManager::getDrivers() will not be able to
357 // provide the actual device list. We will not be able to find the test
358 // driver with specified deviceName.
359 if (!DeviceManager::get()->getUseCpuOnly()) {
360 EXPECT_EQ(mDevices.size(), (uint32_t)1);
361
362 int result = ANeuralNetworksCompilation_createForDevices(
363 model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
364 EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
365 }
366 }
367 };
368
369 template <class DriverClass>
370 class ExecutionTestTemplate
371 : public ::testing::TestWithParam<std::tuple<ErrorStatus, Result, bool>> {
372 public:
ExecutionTestTemplate()373 ExecutionTestTemplate()
374 : kName(toString(std::get<0>(GetParam()))),
375 kForceErrorStatus(std::get<0>(GetParam())),
376 kExpectResult(std::get<1>(GetParam())),
377 kUseIntrospectionAPI(std::get<2>(GetParam())),
378 mModel(makeModel()) {
379 if (kUseIntrospectionAPI) {
380 DeviceManager::get()->forTest_registerDevice(kName.c_str(),
381 new DriverClass(kName, kForceErrorStatus));
382 mCompilation = TestIntrospectionCompilation(&mModel, kName);
383 } else {
384 mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
385 }
386 }
387
388 protected:
389 // Unit test method
390 void TestWait();
391
TearDown()392 virtual void TearDown() {
393 // Reinitialize the device list since Introspection API path altered it.
394 if (kUseIntrospectionAPI) {
395 DeviceManager::get()->forTest_reInitializeDeviceList();
396 }
397 }
398
399 const std::string kName;
400
401 // Allow dummying up the error status for execution. If
402 // kForceErrorStatus is NONE, then execution behaves normally (and
403 // sends back the actual execution status). Otherwise, don't
404 // bother to execute, and just send back kForceErrorStatus (as the
405 // execution status, not the launch status).
406 const ErrorStatus kForceErrorStatus;
407
408 // What result do we expect from the execution? (The Result
409 // equivalent of kForceErrorStatus.)
410 const Result kExpectResult;
411
412 // Whether mCompilation is created via Introspection API or not.
413 const bool kUseIntrospectionAPI;
414
415 WrapperModel mModel;
416 WrapperCompilation mCompilation;
417
setInputOutput(WrapperExecution * execution)418 void setInputOutput(WrapperExecution* execution) {
419 mInputBuffer = kInputBuffer;
420 mOutputBuffer = kOutputBufferInitial;
421 ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)), Result::NO_ERROR);
422 ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)), Result::NO_ERROR);
423 }
424
425 const float kInputBuffer = 3.14;
426 const float kOutputBufferInitial = 0;
427 float mInputBuffer;
428 float mOutputBuffer;
429 const float kOutputBufferExpected = 3;
430 const std::vector<uint32_t> kOutputDimensionsExpected = {1};
431
432 private:
makeModel()433 static WrapperModel makeModel() {
434 static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, { 1 });
435
436 WrapperModel model;
437 uint32_t input = model.addOperand(&tensorType);
438 uint32_t output = model.addOperand(&tensorType);
439 model.addOperation(ANEURALNETWORKS_FLOOR, { input }, { output });
440 model.identifyInputsAndOutputs({ input }, { output } );
441 assert(model.finish() == Result::NO_ERROR);
442
443 return model;
444 }
445 };
446
TestWait()447 template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() {
448 SCOPED_TRACE(kName);
449 // Skip Introspection API tests when CPU only flag is forced on.
450 if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
451 GTEST_SKIP();
452 }
453
454 ASSERT_EQ(mCompilation.finish(), Result::NO_ERROR);
455
456 {
457 SCOPED_TRACE("startCompute");
458 WrapperExecution execution(&mCompilation);
459 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
460 WrapperEvent event;
461 ASSERT_EQ(execution.startCompute(&event), Result::NO_ERROR);
462 ASSERT_EQ(event.wait(), kExpectResult);
463 if (kExpectResult == Result::NO_ERROR) {
464 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
465 }
466 std::vector<uint32_t> dimensions;
467 if (kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
468 // Only one output operand, hardcoded as index 0.
469 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
470 Result::OUTPUT_INSUFFICIENT_SIZE);
471 } else {
472 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::NO_ERROR);
473 }
474 if (kExpectResult == Result::NO_ERROR ||
475 kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
476 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
477 }
478 }
479 {
480 SCOPED_TRACE("compute");
481 WrapperExecution execution(&mCompilation);
482 ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
483 ASSERT_EQ(execution.compute(), kExpectResult);
484 if (kExpectResult == Result::NO_ERROR) {
485 ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
486 }
487 std::vector<uint32_t> dimensions;
488 if (kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
489 // Only one output operand, hardcoded as index 0.
490 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
491 Result::OUTPUT_INSUFFICIENT_SIZE);
492 } else {
493 ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::NO_ERROR);
494 }
495 if (kExpectResult == Result::NO_ERROR ||
496 kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
497 ASSERT_EQ(dimensions, kOutputDimensionsExpected);
498 }
499 }
500 }
501
502 auto kTestValues = ::testing::Values(
503 std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ false),
504 std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
505 /* kUseIntrospectionAPI */ false),
506 std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
507 /* kUseIntrospectionAPI */ false),
508 std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
509 /* kUseIntrospectionAPI */ false),
510 std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
511 /* kUseIntrospectionAPI */ false));
512
513 class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12,Wait)514 TEST_P(ExecutionTest12, Wait) {
515 TestWait();
516 }
517 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest12, kTestValues);
518
519 class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11,Wait)520 TEST_P(ExecutionTest11, Wait) {
521 if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
522 TestWait();
523 }
524 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest11, kTestValues);
525
526 class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10,Wait)527 TEST_P(ExecutionTest10, Wait) {
528 if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
529 TestWait();
530 }
531 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest10, kTestValues);
532
533 auto kIntrospectionTestValues = ::testing::Values(
534 std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ true),
535 std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
536 /* kUseIntrospectionAPI */ true),
537 std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
538 /* kUseIntrospectionAPI */ true),
539 std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
540 /* kUseIntrospectionAPI */ true),
541 std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
542 /* kUseIntrospectionAPI */ true));
543
544 INSTANTIATE_TEST_CASE_P(IntrospectionFlavor, ExecutionTest12, kIntrospectionTestValues);
545
546 } // namespace
547 } // namespace android
548