/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>

#include <android-base/scopeguard.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <memory>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "HalInterfaces.h"
#include "Manager.h"
#include "ModelBuilder.h"
#include "NeuralNetworks.h"
#include "SampleDriver.h"
#include "TestNeuralNetworksWrapper.h"
#include "Utils.h"
#include "ValidateHal.h"

namespace android {

using namespace nn::hal;
using CompilationBuilder = nn::CompilationBuilder;
using Device = nn::Device;
using DeviceManager = nn::DeviceManager;
using HidlModel = V1_3::Model;
using PreparedModelCallback = nn::PreparedModelCallback;
using Result = nn::test_wrapper::Result;
using SampleDriver = nn::sample_driver::SampleDriver;
using WrapperCompilation = nn::test_wrapper::Compilation;
using WrapperEvent = nn::test_wrapper::Event;
using WrapperExecution = nn::test_wrapper::Execution;
using WrapperModel = nn::test_wrapper::Model;
using WrapperOperandType = nn::test_wrapper::OperandType;
using WrapperType = nn::test_wrapper::Type;
using nn::convertToV1_0;

template <typename T>
using MQDescriptorSync = hardware::MQDescriptorSync<T>;

namespace {

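// Timing reported when no measurement is available; UINT64_MAX in both fields
// means the time was not measured.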
const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};

// Wraps the latest version of IPreparedModel to allow dummying up the execution status
// and to control when the execution finishes.
class TestPreparedModelLatest : public IPreparedModel {
   public:
    // If errorStatus is NONE, then execute behaves normally (and sends back
    // the actual execution status). Otherwise, don't bother to execute, and
    // just send back errorStatus (as the execution status, not the launch
    // status).
    TestPreparedModelLatest(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
        : mPreparedModelV1_0(preparedModel),
          mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(preparedModel).withDefault(nullptr)),
          mErrorStatus(errorStatus) {}

    Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                      const sp<V1_0::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_0 != nullptr) << "V1_0 prepared model is nullptr.";
        std::thread([this, &request, &callback] {
            dummyExecution();
            if (mErrorStatus == ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_0->execute(request, callback);
            } else {
                callback->notify(convertToV1_0(mErrorStatus));
            }
        }).detach();
        return V1_0::ErrorStatus::NONE;
    }

    Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, MeasureTiming measure,
                                          const sp<V1_2::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        std::thread([this, &request, measure, &callback] {
            dummyExecution();
            if (mErrorStatus == ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_2->execute_1_2(request, measure, callback);
            } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_2(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            } else {
                callback->notify_1_2(convertToV1_0(mErrorStatus), {}, kBadTiming);
            }
        }).detach();
        return V1_0::ErrorStatus::NONE;
    }

    Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request& request, MeasureTiming measure,
                                          const OptionalTimePoint& deadline,
                                          const OptionalTimeoutDuration& loopTimeoutDuration,
                                          const sp<V1_3::IExecutionCallback>& callback) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        std::thread([this, &request, measure, &deadline, &loopTimeoutDuration, &callback] {
            dummyExecution();
            if (mErrorStatus == ErrorStatus::NONE) {
                // Note that we lose the actual launch status.
                (void)mPreparedModelV1_3->execute_1_3(request, measure, deadline,
                                                      loopTimeoutDuration, callback);
            } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
                OutputShape shape = {.dimensions = {1}, .isSufficient = false};
                callback->notify_1_3(mErrorStatus, {shape}, kBadTiming);
            } else {
                callback->notify_1_3(mErrorStatus, {}, kBadTiming);
            }
        }).detach();
        return V1_3::ErrorStatus::NONE;
    }

    Return<void> executeSynchronously(const V1_0::Request& request, MeasureTiming measure,
                                      executeSynchronously_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == ErrorStatus::NONE) {
            return mPreparedModelV1_2->executeSynchronously(request, measure, cb);
        } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(convertToV1_0(mErrorStatus), {shape}, kBadTiming);
            return Void();
        } else {
            cb(convertToV1_0(mErrorStatus), {}, kBadTiming);
            return Void();
        }
    }

    Return<void> executeSynchronously_1_3(const V1_3::Request& request, MeasureTiming measure,
                                          const OptionalTimePoint& deadline,
                                          const OptionalTimeoutDuration& loopTimeoutDuration,
                                          executeSynchronously_1_3_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        dummyExecution();
        if (mErrorStatus == ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeSynchronously_1_3(request, measure, deadline,
                                                                loopTimeoutDuration, cb);
        } else if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
            OutputShape shape = {.dimensions = {1}, .isSufficient = false};
            cb(mErrorStatus, {shape}, kBadTiming);
            return Void();
        } else {
            cb(mErrorStatus, {}, kBadTiming);
            return Void();
        }
    }

    Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override {
        CHECK(mPreparedModelV1_2 != nullptr) << "V1_2 prepared model is nullptr.";
        if (mErrorStatus == ErrorStatus::NONE) {
            return mPreparedModelV1_2->configureExecutionBurst(callback, requestChannel,
                                                               resultChannel, cb);
        } else {
            cb(convertToV1_0(mErrorStatus), nullptr);
            return Void();
        }
    }

    // Note: due to a limitation of the SampleDriver implementation, the call
    // is synchronous. The test code that exercises this implementation of
    // SampleDriver is written with that in mind. Therefore, this
    // implementation is synchronous also. If the SampleDriver is updated to
    // return a real sync fence, this must be updated.
    Return<void> executeFenced(const V1_3::Request& request, const hidl_vec<hidl_handle>& waitFor,
                               MeasureTiming measure, const OptionalTimePoint& deadline,
                               const OptionalTimeoutDuration& loopTimeoutDuration,
                               const OptionalTimeoutDuration& duration,
                               executeFenced_cb cb) override {
        CHECK(mPreparedModelV1_3 != nullptr) << "V1_3 prepared model is nullptr.";
        CHECK(mErrorStatus != ErrorStatus::OUTPUT_INSUFFICIENT_SIZE)
                << "executeFenced does not support dynamic output shape";
        dummyExecution();
        if (mErrorStatus == ErrorStatus::NONE) {
            return mPreparedModelV1_3->executeFenced(request, waitFor, measure, deadline,
                                                     loopTimeoutDuration, duration, cb);
        } else {
            // Due to the limitations of the SampleDriver, all failures look
            // like launch failures. If the SampleDriver is updated to return
            // real sync fences, this must be updated.
            cb(mErrorStatus, hidl_handle(nullptr), nullptr);
        }
        return Void();
    }

    // We can place the TestPreparedModelLatest system in a "pause" mode where
    // no execution will complete until the system is taken out of that mode.
    // Initially, the system is not in that mode.
    static void pauseExecutions(bool v) { mPauseExecutions.store(v); }

    // This function is only guaranteed to work in the following pattern:
    // - pauseExecutions(true);
    // - // launch execution
    // - // thread A: waitForExecutionToBegin()
    // - // thread B: pauseExecutions(false);
    static void waitForExecutionToBegin() {
        CHECK(mPauseExecutions.load());
        while (mExecutionsInFlight.load() == 0) {
        }
    }

   private:
    const sp<V1_0::IPreparedModel> mPreparedModelV1_0;
    const sp<V1_2::IPreparedModel> mPreparedModelV1_2;
    const sp<V1_3::IPreparedModel> mPreparedModelV1_3;
    ErrorStatus mErrorStatus;

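    // Shared across all TestPreparedModelLatest instances: a global "pause" flag, and a
    // count of executions currently spinning inside dummyExecution().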
    static std::atomic<bool> mPauseExecutions;
    static std::atomic<unsigned int> mExecutionsInFlight;

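    // Called at the start of every execution path: bumps the in-flight count, spins while
    // executions are paused, then drops the count before the (real or dummied) execution
    // proceeds.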
    static void dummyExecution() {
        CHECK_EQ(mExecutionsInFlight.fetch_add(1), 0u) << "We do not support concurrent executions";
        while (mPauseExecutions.load()) {
        }
        mExecutionsInFlight.fetch_sub(1);
    }
};
std::atomic<bool> TestPreparedModelLatest::mPauseExecutions = false;
std::atomic<unsigned int> TestPreparedModelLatest::mExecutionsInFlight = 0;

using TestPreparedModel13 = TestPreparedModelLatest;

// Like TestPreparedModelLatest, but implementing 1.2
class TestPreparedModel12 : public V1_2::IPreparedModel {
   public:
    TestPreparedModel12(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
        : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}

    Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                      const sp<V1_0::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute(request, callback);
    }

    Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, MeasureTiming measure,
                                          const sp<V1_2::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute_1_2(request, measure, callback);
    }

    Return<void> executeSynchronously(const V1_0::Request& request, MeasureTiming measure,
                                      executeSynchronously_cb cb) override {
        return mLatestPreparedModel->executeSynchronously(request, measure, cb);
    }

    Return<void> configureExecutionBurst(
            const sp<V1_2::IBurstCallback>& callback,
            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override {
        return mLatestPreparedModel->configureExecutionBurst(callback, requestChannel,
                                                             resultChannel, cb);
    }

   private:
    const sp<IPreparedModel> mLatestPreparedModel;
};

// Like TestPreparedModelLatest, but implementing 1.0
class TestPreparedModel10 : public V1_0::IPreparedModel {
   public:
    TestPreparedModel10(sp<V1_0::IPreparedModel> preparedModel, ErrorStatus errorStatus)
        : mLatestPreparedModel(new TestPreparedModelLatest(preparedModel, errorStatus)) {}

    Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                      const sp<V1_0::IExecutionCallback>& callback) override {
        return mLatestPreparedModel->execute(request, callback);
    }

   private:
    const sp<IPreparedModel> mLatestPreparedModel;
};

// Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
class TestDriver13 : public SampleDriver {
   public:
    // Allow dummying up the error status for execution of all models
    // prepared from this driver. If errorStatus is NONE, then
    // execute behaves normally (and sends back the actual execution
    // status). Otherwise, don't bother to execute, and just send
    // back errorStatus (as the execution status, not the launch
    // status).
    TestDriver13(const std::string& name, ErrorStatus errorStatus)
        : SampleDriver(name.c_str()), mErrorStatus(errorStatus) {}

    Return<void> getCapabilities_1_3(getCapabilities_1_3_cb _hidl_cb) override {
        android::nn::initVLogMask();
        const PerformanceInfo kPerf = {.execTime = 0.75f, .powerUsage = 0.75f};
        Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance =
                        nn::nonExtensionOperandPerformance<nn::HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        _hidl_cb(V1_3::ErrorStatus::NONE, capabilities);
        return Void();
    }

    Return<void> getSupportedOperations_1_3(const HidlModel& model,
                                            getSupportedOperations_1_3_cb cb) override {
        if (nn::validateModel(model)) {
            std::vector<bool> supported(model.main.operations.size(), true);
            cb(V1_3::ErrorStatus::NONE, supported);
        } else {
            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
        }
        return Void();
    }

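    // Prepare the model with the real SampleDriver via a local callback, then report a
    // TestPreparedModel13 that wraps the resulting prepared model back to the actual
    // callback, so that executions can be dummied up.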
    Return<V1_3::ErrorStatus> prepareModel_1_3(
            const HidlModel& model, ExecutionPreference preference, Priority priority,
            const OptionalTimePoint& deadline, const hidl_vec<hidl_handle>& modelCache,
            const hidl_vec<hidl_handle>& dataCache, const CacheToken& token,
            const sp<V1_3::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        Return<V1_3::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_3(
                model, preference, priority, deadline, modelCache, dataCache, token, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != ErrorStatus::NONE) {
            actualCallback->notify_1_3(
                    localCallback->getStatus(),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify_1_3(
                    localCallback->getStatus(),
                    V1_3::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            actualCallback->notify_1_3(
                    V1_3::ErrorStatus::NONE,
                    new TestPreparedModel13(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    Return<V1_0::ErrorStatus> prepareModel_1_2(
            const V1_2::Model& model, ExecutionPreference preference,
            const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
            const CacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        Return<V1_0::ErrorStatus> prepareModelReturn = SampleDriver::prepareModel_1_2(
                model, preference, modelCache, dataCache, token, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify_1_2(
                    convertToV1_0(localCallback->getStatus()),
                    V1_2::IPreparedModel::castFrom(localCallback->getPreparedModel()));
        } else {
            actualCallback->notify_1_2(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel12(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
        Return<V1_0::ErrorStatus> prepareModelReturn =
                SampleDriver::prepareModel_1_1(model, preference, localCallback);
        if (!prepareModelReturn.isOkUnchecked()) {
            return prepareModelReturn;
        }
        if (prepareModelReturn != V1_0::ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
            return prepareModelReturn;
        }
        localCallback->wait();
        if (localCallback->getStatus() != ErrorStatus::NONE) {
            actualCallback->notify(convertToV1_0(localCallback->getStatus()),
                                   localCallback->getPreparedModel());
        } else {
            actualCallback->notify(
                    V1_0::ErrorStatus::NONE,
                    new TestPreparedModel10(localCallback->getPreparedModel(), mErrorStatus));
        }
        return prepareModelReturn;
    }

    Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return prepareModel_1_1(nn::convertToV1_1(model), ExecutionPreference::FAST_SINGLE_ANSWER,
                                actualCallback);
    }

   private:
    ErrorStatus mErrorStatus;
};

// Like TestDriver, but implementing 1.2
class TestDriver12 : public V1_2::IDevice {
   public:
    TestDriver12(const std::string& name, ErrorStatus errorStatus)
        : mLatestDriver(new TestDriver13(name, errorStatus)) {}
    Return<void> getCapabilities_1_2(getCapabilities_1_2_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities_1_2(_hidl_cb);
    }
    Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities_1_1(_hidl_cb);
    }
    Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities(_hidl_cb);
    }
    Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
                                            getSupportedOperations_1_2_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations_1_2(model, _hidl_cb);
    }
    Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
                                            getSupportedOperations_1_1_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
    }
    Return<void> getSupportedOperations(const V1_0::Model& model,
                                        getSupportedOperations_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations(model, _hidl_cb);
    }
    Return<V1_0::ErrorStatus> prepareModel_1_2(
            const V1_2::Model& model, ExecutionPreference preference,
            const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
            const CacheToken& token,
            const sp<V1_2::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel_1_2(model, preference, modelCache, dataCache, token,
                                               actualCallback);
    }
    Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
    }
    Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel(model, actualCallback);
    }
    Return<DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
    Return<void> getVersionString(getVersionString_cb _hidl_cb) override {
        return mLatestDriver->getVersionString(_hidl_cb);
    }
    Return<void> getType(getType_cb _hidl_cb) override { return mLatestDriver->getType(_hidl_cb); }
    Return<void> getSupportedExtensions(getSupportedExtensions_cb _hidl_cb) {
        return mLatestDriver->getSupportedExtensions(_hidl_cb);
    }
    Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb _hidl_cb) {
        return mLatestDriver->getNumberOfCacheFilesNeeded(_hidl_cb);
    }
    Return<V1_0::ErrorStatus> prepareModelFromCache(
            const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
            const CacheToken& token, const sp<V1_2::IPreparedModelCallback>& callback) {
        return mLatestDriver->prepareModelFromCache(modelCache, dataCache, token, callback);
    }

   private:
    const sp<V1_3::IDevice> mLatestDriver;
};

// Like TestDriver, but implementing 1.1
class TestDriver11 : public V1_1::IDevice {
   public:
    TestDriver11(const std::string& name, ErrorStatus errorStatus)
        : mLatestDriver(new TestDriver13(name, errorStatus)) {}
    Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities_1_1(_hidl_cb);
    }
    Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
                                            getSupportedOperations_1_1_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations_1_1(model, _hidl_cb);
    }
    Return<V1_0::ErrorStatus> prepareModel_1_1(
            const V1_1::Model& model, ExecutionPreference preference,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel_1_1(model, preference, actualCallback);
    }
    Return<DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }
    Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities(_hidl_cb);
    }
    Return<void> getSupportedOperations(const V1_0::Model& model,
                                        getSupportedOperations_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations(model, _hidl_cb);
    }
    Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel(model, actualCallback);
    }

   private:
    const sp<V1_3::IDevice> mLatestDriver;
};

// Like TestDriver, but implementing 1.0
class TestDriver10 : public V1_0::IDevice {
   public:
    TestDriver10(const std::string& name, ErrorStatus errorStatus)
        : mLatestDriver(new TestDriver13(name, errorStatus)) {}
    Return<void> getCapabilities(getCapabilities_cb _hidl_cb) override {
        return mLatestDriver->getCapabilities(_hidl_cb);
    }
    Return<void> getSupportedOperations(const V1_0::Model& model,
                                        getSupportedOperations_cb _hidl_cb) override {
        return mLatestDriver->getSupportedOperations(model, _hidl_cb);
    }
    Return<V1_0::ErrorStatus> prepareModel(
            const V1_0::Model& model,
            const sp<V1_0::IPreparedModelCallback>& actualCallback) override {
        return mLatestDriver->prepareModel(model, actualCallback);
    }
    Return<DeviceStatus> getStatus() override { return mLatestDriver->getStatus(); }

   private:
    const sp<V1_3::IDevice> mLatestDriver;
};

// This class adds some simple utilities on top of WrapperCompilation in order
// to provide access to certain features from CompilationBuilder that are not
// exposed by the base class.
template <typename DriverClass>
class TestCompilation : public WrapperCompilation {
   public:
    // Allow dummying up the error status for all executions from this
    // compilation. If errorStatus is NONE, then execute behaves
    // normally (and sends back the actual execution status).
    // Otherwise, don't bother to execute, and just send back
    // errorStatus (as the execution status, not the launch status).
    TestCompilation(const WrapperModel* model, const std::string& deviceName,
                    ErrorStatus errorStatus) {
        std::vector<std::shared_ptr<Device>> devices;
        auto device = DeviceManager::forTest_makeDriverDevice(
                deviceName, new DriverClass(deviceName, errorStatus));
        devices.push_back(device);

        nn::ModelBuilder* m = reinterpret_cast<nn::ModelBuilder*>(model->getHandle());
        CompilationBuilder* c = nullptr;
        int result = m->createCompilation(&c, devices);
        EXPECT_EQ(result, 0);
        // We need to ensure that we use our TestDriver and do not
        // fall back to CPU. (If we allow CPU fallback, then when our
        // TestDriver reports an execution failure, we'll re-execute
        // on CPU, and will not see the failure.)
        c->setPartitioning(DeviceManager::kPartitioningWithoutFallback);
        mCompilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
    }
};

// This class has roughly the same functionality as the TestCompilation class.
// The major difference is that the Introspection API is used to select the device.
class TestIntrospectionCompilation : public WrapperCompilation {
   public:
    TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
        std::vector<ANeuralNetworksDevice*> mDevices;
        uint32_t numDevices = 0;
        EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
        EXPECT_GE(numDevices, (uint32_t)1);

        for (uint32_t i = 0; i < numDevices; i++) {
            ANeuralNetworksDevice* device = nullptr;
            EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
            const char* buffer = nullptr;
            int result = ANeuralNetworksDevice_getName(device, &buffer);
            if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
                mDevices.push_back(device);
            }
        }
        // In CPU-only mode, DeviceManager::getDrivers() will not be able to
        // provide the actual device list, so we will not be able to find the
        // test driver with the specified deviceName.
        if (!DeviceManager::get()->getUseCpuOnly()) {
            EXPECT_EQ(mDevices.size(), (uint32_t)1);

            int result = ANeuralNetworksCompilation_createForDevices(
                    model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
            EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
        }
    }
};

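// Parameterized test fixture. The test parameter is a tuple of
// (ErrorStatus forced on the driver, Result expected from the runtime,
//  whether the device is selected via the Introspection API).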
template <class DriverClass>
class ExecutionTestTemplate
    : public ::testing::TestWithParam<std::tuple<ErrorStatus, Result, bool>> {
   public:
    ExecutionTestTemplate()
        : kName(toString(std::get<0>(GetParam()))),
          kForceErrorStatus(std::get<0>(GetParam())),
          kExpectResult(std::get<1>(GetParam())),
          kUseIntrospectionAPI(std::get<2>(GetParam())),
          mModel(makeModel()) {
        if (kUseIntrospectionAPI) {
            DeviceManager::get()->forTest_registerDevice(
                    kName.c_str(), new DriverClass(kName, kForceErrorStatus));
            mCompilation = TestIntrospectionCompilation(&mModel, kName);
        } else {
            mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
        }
    }

   protected:
    // Unit test method
    void TestWait();

    virtual void TearDown() {
        // Reinitialize the device list, since the Introspection API path altered it.
        if (kUseIntrospectionAPI) {
            DeviceManager::get()->forTest_reInitializeDeviceList();
        }
    }

    const std::string kName;

    // Allow dummying up the error status for execution. If
    // kForceErrorStatus is NONE, then execution behaves normally (and
    // sends back the actual execution status). Otherwise, don't
    // bother to execute, and just send back kForceErrorStatus (as the
    // execution status, not the launch status).
    const ErrorStatus kForceErrorStatus;

    // What result do we expect from the execution? (The Result
    // equivalent of kForceErrorStatus.)
    const Result kExpectResult;

    // Whether mCompilation is created via Introspection API or not.
    const bool kUseIntrospectionAPI;

    WrapperModel mModel;
    WrapperCompilation mCompilation;

    void setInputOutput(WrapperExecution* execution) {
        mInputBuffer = kInputBuffer;
        mOutputBuffer = kOutputBufferInitial;
        ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)), Result::NO_ERROR);
        ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)), Result::NO_ERROR);
    }

    const float kInputBuffer = 3.14;
    const float kOutputBufferInitial = 0;
    float mInputBuffer;
    float mOutputBuffer;
    const float kOutputBufferExpected = 3;
    const std::vector<uint32_t> kOutputDimensionsExpected = {1};

   private:
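    // Builds a trivial model containing a single FLOOR operation on a {1}-element
    // TENSOR_FLOAT32 operand, so floor(kInputBuffer = 3.14) == 3 == kOutputBufferExpected.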
    static WrapperModel makeModel() {
        static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {1});

        WrapperModel model;
        uint32_t input = model.addOperand(&tensorType);
        uint32_t output = model.addOperand(&tensorType);
        model.addOperation(ANEURALNETWORKS_FLOOR, {input}, {output});
        model.identifyInputsAndOutputs({input}, {output});
        assert(model.finish() == Result::NO_ERROR);

        return model;
    }
};

template <class DriverClass>
void ExecutionTestTemplate<DriverClass>::TestWait() {
    SCOPED_TRACE(kName);
    // Skip the Introspection API tests when the CPU-only flag is forced on.
    if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
        GTEST_SKIP();
    }

    ASSERT_EQ(mCompilation.finish(), Result::NO_ERROR);

    const auto getDimensionsWhileRunning = [](WrapperExecution& execution) {
        TestPreparedModelLatest::waitForExecutionToBegin();
        // Cannot query dimensions while execution is running
        std::vector<uint32_t> dimensions;
        EXPECT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::BAD_STATE);
    };

    {
        SCOPED_TRACE("startCompute");
        WrapperExecution execution(&mCompilation);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
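        // Pause the driver so the execution stays in flight while we verify that output
        // dimensions cannot be queried mid-execution.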
        TestPreparedModelLatest::pauseExecutions(true);
        WrapperEvent event;
        ASSERT_EQ(execution.startCompute(&event), Result::NO_ERROR);
        getDimensionsWhileRunning(execution);
        TestPreparedModelLatest::pauseExecutions(false);
        ASSERT_EQ(event.wait(), kExpectResult);
        if (kExpectResult == Result::NO_ERROR) {
            ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
        }
        std::vector<uint32_t> dimensions;
        if (kExpectResult == Result::NO_ERROR ||
            kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
            // Only one output operand, hardcoded as index 0.
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
            ASSERT_EQ(dimensions, kOutputDimensionsExpected);
        } else {
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::BAD_STATE);
        }
    }
    {
        SCOPED_TRACE("compute");
        WrapperExecution execution(&mCompilation);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        TestPreparedModelLatest::pauseExecutions(true);
        std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
        getDimensionsWhileRunning(execution);
        TestPreparedModelLatest::pauseExecutions(false);
        run.join();
        if (kExpectResult == Result::NO_ERROR) {
            ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
        }
        std::vector<uint32_t> dimensions;
        if (kExpectResult == Result::NO_ERROR ||
            kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
            // Only one output operand, hardcoded as index 0.
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
            ASSERT_EQ(dimensions, kOutputDimensionsExpected);
        } else {
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::BAD_STATE);
        }
    }
    {
        SCOPED_TRACE("burstCompute");

        // TODO: If a burst API is added to nn::test_wrapper (e.g.,
        // Execution::burstCompute()), then use that, rather than using
        // Execution::setComputeMode() to make Execution::compute() use burst
        // functionality.

        auto oldComputeMode =
                WrapperExecution::setComputeMode(WrapperExecution::ComputeMode::BURST);
        base::ScopeGuard restore(
                [oldComputeMode] { WrapperExecution::setComputeMode(oldComputeMode); });

        WrapperExecution execution(&mCompilation);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        TestPreparedModelLatest::pauseExecutions(true);
        std::thread run([this, &execution] { EXPECT_EQ(execution.compute(), kExpectResult); });
        getDimensionsWhileRunning(execution);
        TestPreparedModelLatest::pauseExecutions(false);
        run.join();
        if (kExpectResult == Result::NO_ERROR) {
            ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
        }
        std::vector<uint32_t> dimensions;
        if (kExpectResult == Result::NO_ERROR ||
            kExpectResult == Result::OUTPUT_INSUFFICIENT_SIZE) {
            // Only one output operand, hardcoded as index 0.
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
            ASSERT_EQ(dimensions, kOutputDimensionsExpected);
        } else {
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::BAD_STATE);
        }
    }
    if (kExpectResult != Result::OUTPUT_INSUFFICIENT_SIZE) {
        // computeWithDependencies doesn't support OUTPUT_INSUFFICIENT_SIZE
        SCOPED_TRACE("computeWithDependencies");
        WrapperExecution execution(&mCompilation);
        ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
        TestPreparedModelLatest::pauseExecutions(true);

        WrapperEvent event;
        // Note: due to a limitation of the SampleDriver implementation, the call is synchronous.
        // If the SampleDriver is updated to return a real sync fence, this must be updated.
        std::thread run([this, &execution, &event] {
            EXPECT_EQ(execution.startComputeWithDependencies({}, 0, &event), kExpectResult);
        });
        getDimensionsWhileRunning(execution);
        TestPreparedModelLatest::pauseExecutions(false);
        run.join();
        if (kExpectResult == Result::NO_ERROR) {
            ASSERT_EQ(event.wait(), kExpectResult);
            ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
        } else {
            ASSERT_EQ(event.wait(), Result::UNEXPECTED_NULL);
        }
        std::vector<uint32_t> dimensions;
        if (kExpectResult == Result::NO_ERROR) {
            // Only one output operand, hardcoded as index 0.
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), kExpectResult);
            ASSERT_EQ(dimensions, kOutputDimensionsExpected);
        } else {
            ASSERT_EQ(execution.getOutputOperandDimensions(0, &dimensions), Result::BAD_STATE);
        }
    }
}

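// Each tuple is (ErrorStatus forced on the driver, Result expected from the runtime,
// kUseIntrospectionAPI).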
auto kTestValues = ::testing::Values(
        std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ false),
        std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ false),
        std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
                        /* kUseIntrospectionAPI */ false));

class ExecutionTest13 : public ExecutionTestTemplate<TestDriver13> {};
TEST_P(ExecutionTest13, Wait) {
    TestWait();
}
INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest13, kTestValues);

class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12, Wait) {
    TestWait();
}
INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest12, kTestValues);

class ExecutionTest11 : public ExecutionTestTemplate<TestDriver11> {};
TEST_P(ExecutionTest11, Wait) {
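    // Drivers implementing HAL versions earlier than 1.2 cannot report dynamic output
    // shapes, so the OUTPUT_INSUFFICIENT_SIZE flavor is skipped for them.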
    if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait();
}
INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest11, kTestValues);

class ExecutionTest10 : public ExecutionTestTemplate<TestDriver10> {};
TEST_P(ExecutionTest10, Wait) {
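    // As with ExecutionTest11, a pre-1.2 driver cannot report dynamic output shapes,
    // so the OUTPUT_INSUFFICIENT_SIZE flavor is skipped.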
    if (kForceErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) return;
    TestWait();
}
INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest10, kTestValues);

auto kIntrospectionTestValues = ::testing::Values(
        std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ true),
        std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
                        /* kUseIntrospectionAPI */ true),
        std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
                        /* kUseIntrospectionAPI */ true));

INSTANTIATE_TEST_CASE_P(IntrospectionFlavor, ExecutionTest13, kIntrospectionTestValues);

} // namespace
} // namespace android