1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #undef NDEBUG
18 
19 #include "Callbacks.h"
20 #include "CompilationBuilder.h"
21 #include "Manager.h"
22 #include "ModelBuilder.h"
23 #include "NeuralNetworks.h"
24 #include "SampleDriver.h"
25 #include "TestNeuralNetworksWrapper.h"
26 #include "ValidateHal.h"
27 
28 #include <algorithm>
29 #include <cassert>
30 #include <vector>
31 
32 #include <gtest/gtest.h>
33 
34 namespace android {
35 
36 using CompilationBuilder = nn::CompilationBuilder;
37 using Device = nn::Device;
38 using DeviceManager = nn::DeviceManager;
39 using HidlModel = hardware::neuralnetworks::V1_2::Model;
40 using HidlToken = hardware::hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
41 using PreparedModelCallback = hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
42 using Result = nn::test_wrapper::Result;
43 using SampleDriver = nn::sample_driver::SampleDriver;
44 using WrapperCompilation = nn::test_wrapper::Compilation;
45 using WrapperEvent = nn::test_wrapper::Event;
46 using WrapperExecution = nn::test_wrapper::Execution;
47 using WrapperModel = nn::test_wrapper::Model;
48 using WrapperOperandType = nn::test_wrapper::OperandType;
49 using WrapperType = nn::test_wrapper::Type;
50 
51 template <typename T>
52 using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
53 
54 namespace {
55 
56 const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
57 
58 // Wraps an V1_2::IPreparedModel to allow dummying up the execution status.
59 class TestPreparedModel12 : public V1_2::IPreparedModel {
60    public:
61     // If errorStatus is NONE, then execute behaves normally (and sends back
62     // the actual execution status).  Otherwise, don't bother to execute, and
63     // just send back errorStatus (as the execution status, not the launch
64     // status).
TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel,ErrorStatus errorStatus)65     TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
66         : mPreparedModelV1_0(preparedModel),
67           mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
68           mErrorStatus(errorStatus) {}
69 
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)70     Return<ErrorStatus> execute(const Request& request,
71                                 const sp<V1_0::IExecutionCallback>& callback) override {
72         CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
73         if (mErrorStatus == ErrorStatus::NONE) {
74             return mPreparedModelV1_0->execute(request, callback);
75         } else {
76             callback->notify(mErrorStatus);
77             return ErrorStatus::NONE;
78         }
79     }
80 
execute_1_2(const Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)81     Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
82                                     const sp<V1_2::IExecutionCallback>& callback) override {
83         CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
84         if (mErrorStatus == ErrorStatus::NONE) {
85             return mPreparedModelV1_2->execute_1_2(request, measure, callback);
86         } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
87             OutputShape shape = {.dimensions = {1}, .isSufficient = false};
88             callback->notify_1_2(mErrorStatus, {shape}, kBadTiming);
89             return ErrorStatus::NONE;
90         } else {
91             callback->notify_1_2(mErrorStatus, {}, kBadTiming);
92             return ErrorStatus::NONE;
93         }
94     }
95 
executeSynchronously(const Request & request,MeasureTiming measure,executeSynchronously_cb cb)96     Return<void> executeSynchronously(const Request& request, MeasureTiming measure,
97                                       executeSynchronously_cb cb) override {
98         CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
99         if (mErrorStatus == ErrorStatus::NONE) {
100             return mPreparedModelV1_2->executeSynchronously(
101                     request, measure,
102                     [&cb](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
103                           const Timing& timing) { cb(error, outputShapes, timing); });
104         } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
105             OutputShape shape = {.dimensions = {1}, .isSufficient = false};
106             cb(mErrorStatus, {shape}, kBadTiming);
107             return Void();
108         } else {
109             cb(mErrorStatus, {}, kBadTiming);
110             return Void();
111         }
112     }
113 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)114     Return<void> configureExecutionBurst(
115             const sp<V1_2::IBurstCallback>& callback,
116             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
117             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
118             configureExecutionBurst_cb cb) override {
119         if (mErrorStatus == ErrorStatus::NONE) {
120             return mPreparedModelV1_2->configureExecutionBurst(callback, requestChannel,
121                                                                resultChannel, cb);
122         } else {
123             cb(mErrorStatus, nullptr);
124             return Void();
125         }
126     }
127 
128    private:
129     const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
130     const sp<V1_2::IPreparedModel> mPreparedModelV1_2;
131     ErrorStatus mErrorStatus;
132 };
133 
134 // Like TestPreparedModel12, but implementing 1.0
135 class TestPreparedModel10 : public V1_0::IPreparedModel {
136    public:
TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel,ErrorStatus errorStatus)137     TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
138         : m12PreparedModel(new TestPreparedModel12(preparedModel, errorStatus)) {}
139 
execute(const Request & request,const sp<V1_0::IExecutionCallback> & callback)140     Return<ErrorStatus> execute(const Request& request,
141                                 const sp<V1_0::IExecutionCallback>& callback) override {
142         return m12PreparedModel->execute(request, callback);
143     }
144 
145    private:
146     const sp<V1_2::IPreparedModel> m12PreparedModel;
147 };
148 
149 // Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
150 class TestDriver12 : public SampleDriver {
151    public:
152     // Allow dummying up the error status for execution of all models
153     // prepared from this driver.  If errorStatus is NONE, then
154     // execute behaves normally (and sends back the actual execution
155     // status).  Otherwise, don't bother to execute, and just send
156     // back errorStatus (as the execution status, not the launch
157     // status).
TestDriver12(const std::string & name,ErrorStatus errorStatus)158     TestDriver12(const std::string& name, ErrorStatus errorStatus)
159         : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}
160 
getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb)161     Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
162         android::nn::initVLogMask();
163         const PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
164         Capabilities capabilities = {
165                 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
166                 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
167                 .operandPerformance = nn::nonExtensionOperandPerformance(kPerf)};
168         _hidl_cb(ErrorStatus::NONE, capabilities);
169         return Void();
170     }
171 
getSupportedOperations_1_2(const HidlModel & model,getSupportedOperations_1_2_cb cb)172     Return<void> getSupportedOperations_1_2(const HidlModel& model,
173                                             getSupportedOperations_1_2_cb cb) override {
174         if (nn::validateModel(model)) {
175             std::vector<bool> supported(model.operations.size(), true);
176             cb(ErrorStatus::NONE, supported);
177         } else {
178             std::vector<bool> supported;
179             cb(ErrorStatus::INVALID_ARGUMENT, supported);
180         }
181         return Void();
182     }
183 
prepareModel_1_2(const HidlModel & model,ExecutionPreference preference,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,const HidlToken & token,const sp<IPreparedModelCallback> & actualCallback)184     Return<ErrorStatus> prepareModel_1_2(
185             const HidlModel& model, ExecutionPreference preference,
186             const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
187             const HidlToken& token, const sp<IPreparedModelCallback>& actualCallback) override {
188         sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
189         Return<ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
190                 model, preference, modelCache, dataCache, token, localCallback);
191         if (!prepareModelReturn.isOkUnchecked()) {
192             return prepareModelReturn;
193         }
194         if (prepareModelReturn != ErrorStatus::NONE) {
195             actualCallback->notify_1_2(
196                     localCallback->getStatus(),
197                     V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
198             return prepareModelReturn;
199         }
200         localCallback->wait();
201         if (localCallback->getStatus() != ErrorStatus::NONE) {
202             actualCallback->notify_1_2(
203                     localCallback->getStatus(),
204                     V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
205         } else {
206             actualCallback->notify_1_2(
207                     ErrorStatus::NONE,
208                     new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
209         }
210         return prepareModelReturn;
211     }
212 
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)213     Return<ErrorStatus> prepareModel_1_1(
214             const V1_1::Model& model, ExecutionPreference preference,
215             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
216         sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
217         Return<ErrorStatus> prepareModelReturn =
218                 SampleDriver::prepareModel_1_1(model, preference, localCallback);
219         if (!prepareModelReturn.isOkUnchecked()) {
220             return prepareModelReturn;
221         }
222         if (prepareModelReturn != ErrorStatus::NONE) {
223             actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
224             return prepareModelReturn;
225         }
226         localCallback->wait();
227         if (localCallback->getStatus() != ErrorStatus::NONE) {
228             actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
229         } else {
230             actualCallback->notify(
231                     ErrorStatus::NONE,
232                     new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
233         }
234         return prepareModelReturn;
235     }
236 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)237     Return<ErrorStatus> prepareModel(
238             const V1_0::Model& model,
239             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
240         return prepareModel_1_1(nn::convertToV1_1(model), ExecutionPreference::FAST_SINGLE_ANSWER,
241                                 actualCallback);
242     }
243 
244 private:
245     ErrorStatus mErrorStatus;
246 };
247 
248 // Like TestDriver, but implementing 1.1
249 class TestDriver11 : public V1_1::IDevice {
250    public:
TestDriver11(const std::string & name,ErrorStatus errorStatus)251     TestDriver11(const std::string& name, ErrorStatus errorStatus)
252         : m12Driver(new TestDriver12(name, errorStatus)) {}
getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb)253     Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
254         return m12Driver->getCapabilities_1_1(_hidl_cb);
255     }
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb _hidl_cb)256     Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
257                                             getSupportedOperations_1_1_cb _hidl_cb) override {
258         return m12Driver->getSupportedOperations_1_1(model, _hidl_cb);
259     }
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & actualCallback)260     Return<ErrorStatus> prepareModel_1_1(
261             const V1_1::Model& model, ExecutionPreference preference,
262             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
263         return m12Driver->prepareModel_1_1(model, preference, actualCallback);
264     }
getStatus()265     Return<DeviceStatus> getStatus() override { return m12Driver->getStatus(); }
getCapabilities(getCapabilities_cb _hidl_cb)266     Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
267         return m12Driver->getCapabilities(_hidl_cb);
268     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)269     Return<void> getSupportedOperations(const V1_0::Model& model,
270                                         getSupportedOperations_cb _hidl_cb) override {
271         return m12Driver->getSupportedOperations(model, _hidl_cb);
272     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)273     Return<ErrorStatus> prepareModel(
274             const V1_0::Model& model,
275             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
276         return m12Driver->prepareModel(model, actualCallback);
277     }
278 
279    private:
280     const sp<V1_2::IDevice> m12Driver;
281 };
282 
283 // Like TestDriver, but implementing 1.0
284 class TestDriver10 : public V1_0::IDevice {
285    public:
TestDriver10(const std::string & name,ErrorStatus errorStatus)286     TestDriver10(const std::string& name, ErrorStatus errorStatus)
287         : m12Driver(new TestDriver12(name, errorStatus)) {}
getCapabilities(getCapabilities_cb _hidl_cb)288     Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
289         return m12Driver->getCapabilities(_hidl_cb);
290     }
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb _hidl_cb)291     Return<void> getSupportedOperations(const V1_0::Model& model,
292                                         getSupportedOperations_cb _hidl_cb) override {
293         return m12Driver->getSupportedOperations(model, _hidl_cb);
294     }
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & actualCallback)295     Return<ErrorStatus> prepareModel(
296             const V1_0::Model& model,
297             const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
298         return m12Driver->prepareModel(model, actualCallback);
299     }
getStatus()300     Return<DeviceStatus> getStatus() override { return m12Driver->getStatus(); }
301 
302    private:
303     const sp<V1_2::IDevice> m12Driver;
304 };
305 
306 // This class adds some simple utilities on top of WrapperCompilation in order
307 // to provide access to certain features from CompilationBuilder that are not
308 // exposed by the base class.
309 template<typename DriverClass>
310 class TestCompilation : public WrapperCompilation {
311 public:
312     // Allow dummying up the error status for all executions from this
313     // compilation.  If errorStatus is NONE, then execute behaves
314     // normally (and sends back the actual execution status).
315     // Otherwise, don't bother to execute, and just send back
316     // errorStatus (as the execution status, not the launch status).
TestCompilation(const WrapperModel * model,const std::string & deviceName,ErrorStatus errorStatus)317     TestCompilation(const WrapperModel* model, const std::string& deviceName,
318                     ErrorStatus errorStatus) {
319         std::vector<std::shared_ptr<Device>> devices;
320         auto device = DeviceManager::forTest_makeDriverDevice(
321                 deviceName, new DriverClass(deviceName, errorStatus));
322         devices.push_back(device);
323 
324         nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
325         CompilationBuilder* c = nullptr;
326         int result = m->createCompilation(&c, devices);
327         EXPECT_EQ(result, 0);
328         // We need to ensure that we use our TestDriver and do not
329         // fall back to CPU.  (If we allow CPU fallback, then when our
330         // TestDriver reports an execution failure, we'll re-execute
331         // on CPU, and will not see the failure.)
332         c->setPartitioning(DeviceManager::kPartitioningWithoutFallback);
333         mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
334     }
335 };
336 
337 // This class has roughly the same functionality as TestCompilation class.
338 // The major difference is that Introspection API is used to select the device.
339 class TestIntrospectionCompilation : public WrapperCompilation {
340    public:
TestIntrospectionCompilation(const WrapperModel * model,const std::string & deviceName)341     TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
342         std::vector<ANeuralNetworksDevice*> mDevices;
343         uint32_t numDevices = 0;
344         EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
345         EXPECT_GE(numDevices, (uint32_t)1);
346 
347         for (uint32_t i = 0; i < numDevices; i++) {
348             ANeuralNetworksDevice* device = nullptr;
349             EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
350             const char* buffer = nullptr;
351             int result = ANeuralNetworksDevice_getName(device, &buffer);
352             if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
353                 mDevices.push_back(device);
354             }
355         }
356         // In CPU only mode, DeviceManager::getDrivers() will not be able to
357         // provide the actual device list. We will not be able to find the test
358         // driver with specified deviceName.
359         if (!DeviceManager::get()->getUseCpuOnly()) {
360             EXPECT_EQ(mDevices.size(), (uint32_t)1);
361 
362             int result = ANeuralNetworksCompilation_createForDevices(
363                     model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
364             EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
365         }
366     }
367 };
368 
369 template <class DriverClass>
370 class ExecutionTestTemplate
371     : public ::testing::TestWithParam<std::tuple<ErrorStatus, Result, bool>> {
372    public:
ExecutionTestTemplate()373     ExecutionTestTemplate()
374         : kName(toString(std::get<0>(GetParam()))),
375           kForceErrorStatus(std::get<0>(GetParam())),
376           kExpectResult(std::get<1>(GetParam())),
377           kUseIntrospectionAPI(std::get<2>(GetParam())),
378           mModel(makeModel()) {
379         if (kUseIntrospectionAPI) {
380             DeviceManager::get()->forTest_registerDevice(kName.c_str(),
381                                                          new DriverClass(kName, kForceErrorStatus));
382             mCompilation = TestIntrospectionCompilation(&mModel, kName);
383         } else {
384             mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
385         }
386     }
387 
388    protected:
389     // Unit test method
390     void TestWait();
391 
TearDown()392     virtual void TearDown() {
393         // Reinitialize the device list since Introspection API path altered it.
394         if (kUseIntrospectionAPI) {
395             DeviceManager::get()->forTest_reInitializeDeviceList();
396         }
397     }
398 
399     const std::string kName;
400 
401     // Allow dummying up the error status for execution.  If
402     // kForceErrorStatus is NONE, then execution behaves normally (and
403     // sends back the actual execution status).  Otherwise, don't
404     // bother to execute, and just send back kForceErrorStatus (as the
405     // execution status, not the launch status).
406     const ErrorStatus kForceErrorStatus;
407 
408     // What result do we expect from the execution?  (The Result
409     // equivalent of kForceErrorStatus.)
410     const Result kExpectResult;
411 
412     // Whether mCompilation is created via Introspection API or not.
413     const bool kUseIntrospectionAPI;
414 
415     WrapperModel mModel;
416     WrapperCompilation mCompilation;
417 
setInputOutput(WrapperExecution * execution)418     void setInputOutput(WrapperExecution* execution) {
419         mInputBuffer = kInputBuffer;
420         mOutputBuffer = kOutputBufferInitial;
421         ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)), Result::NO_ERROR);
422         ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)), Result::NO_ERROR);
423     }
424 
425     const float kInputBuffer = 3.14;
426     const float kOutputBufferInitial = 0;
427     float mInputBuffer;
428     float mOutputBuffer;
429     const float kOutputBufferExpected = 3;
430     const std::vector<uint32_t> kOutputDimensionsExpected = {1};
431 
432    private:
makeModel()433     static WrapperModel makeModel() {
434         static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, { 1 });
435 
436         WrapperModel model;
437         uint32_t input = model.addOperand(&tensorType);
438         uint32_t output = model.addOperand(&tensorType);
439         model.addOperation(ANEURALNETWORKS_FLOOR, { input }, { output });
440         model.identifyInputsAndOutputs({ input }, { output } );
441         assert(model.finish() == Result::NO_ERROR);
442 
443         return model;
444     }
445 };
446 
TestWait()447 template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() {
448     SCOPED_TRACE(kName);
449     // Skip Introspection API tests when CPU only flag is forced on.
450     if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
451         GTEST_SKIP();
452     }
453 
454     ASSERT_EQ(mCompilation.finish(), Result::NO_ERROR);
455 
456     {
457         SCOPED_TRACE("startCompute");
458         WrapperExecution execution(&mCompilation);
459         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
460         WrapperEvent event;
461         ASSERT_EQ(execution.startCompute(&event), Result::NO_ERROR);
462         ASSERT_EQ(event.wait(), kExpectResult);
463         if (kExpectResult == Result::NO_ERROR) {
464             ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
465         }
466         std::vector<uint32_t> dimensions;
467         if (kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
468             // Only one output operand, hardcoded as index 0.
469             ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
470                       Result::OUTPUT_INSUFFICIENT_SIZE);
471         } else {
472             ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::NO_ERROR);
473         }
474         if (kExpectResult == Result::NO_ERROR ||
475             kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
476             ASSERT_EQ(dimensions, kOutputDimensionsExpected);
477         }
478     }
479     {
480         SCOPED_TRACE("compute");
481         WrapperExecution execution(&mCompilation);
482         ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
483         ASSERT_EQ(execution.compute(), kExpectResult);
484         if (kExpectResult == Result::NO_ERROR) {
485             ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
486         }
487         std::vector<uint32_t> dimensions;
488         if (kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
489             // Only one output operand, hardcoded as index 0.
490             ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions),
491                       Result::OUTPUT_INSUFFICIENT_SIZE);
492         } else {
493             ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::NO_ERROR);
494         }
495         if (kExpectResult == Result::NO_ERROR ||
496             kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
497             ASSERT_EQ(dimensions, kOutputDimensionsExpected);
498         }
499     }
500 }
501 
502 auto kTestValues = ::testing::Values(
503         std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ false),
504         std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
505                         /* kUseIntrospectionAPI */ false),
506         std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
507                         /* kUseIntrospectionAPI */ false),
508         std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
509                         /* kUseIntrospectionAPI */ false),
510         std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
511                         /* kUseIntrospectionAPI */ false));
512 
513 class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12,Wait)514 TEST_P(ExecutionTest12, Wait) {
515     TestWait();
516 }
517 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest12, kTestValues);
518 
519 class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11,Wait)520 TEST_P(ExecutionTest11, Wait) {
521     if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
522     TestWait();
523 }
524 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest11, kTestValues);
525 
526 class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10,Wait)527 TEST_P(ExecutionTest10, Wait) {
528     if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
529     TestWait();
530 }
531 INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest10, kTestValues);
532 
533 auto kIntrospectionTestValues = ::testing::Values(
534         std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ true),
535         std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
536                         /* kUseIntrospectionAPI */ true),
537         std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
538                         /* kUseIntrospectionAPI */ true),
539         std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
540                         /* kUseIntrospectionAPI */ true),
541         std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
542                         /* kUseIntrospectionAPI */ true));
543 
544 INSTANTIATE_TEST_CASE_P(IntrospectionFlavor, ExecutionTest12, kIntrospectionTestValues);
545 
546 }  // namespace
547 }  // namespace android
548