1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <gmock/gmock.h>
#include <limits>
#include <memory>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/ResilientDevice.h>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "MockBuffer.h"
#include "MockDevice.h"
#include "MockPreparedModel.h"
26
27 namespace android::hardware::neuralnetworks::utils {
28 namespace {
29
30 using ::testing::_;
31 using ::testing::Return;
32
33 using SharedMockDevice = std::shared_ptr<const nn::MockDevice>;
34 using MockDeviceFactory = ::testing::MockFunction<nn::GeneralResult<nn::SharedDevice>(bool)>;
35
36 const std::string kName = "Google-MockV1";
37 const std::string kVersionString = "version1";
38 const auto kExtensions = std::vector<nn::Extension>{};
39 constexpr auto kNoInfo = std::numeric_limits<float>::max();
40 constexpr auto kNoPerformanceInfo =
41 nn::Capabilities::PerformanceInfo{.execTime = kNoInfo, .powerUsage = kNoInfo};
42 const auto kCapabilities = nn::Capabilities{
43 .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
44 .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
45 .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value(),
46 .ifPerformance = kNoPerformanceInfo,
47 .whilePerformance = kNoPerformanceInfo};
48 constexpr auto kNumberOfCacheFilesNeeded = std::pair<uint32_t, uint32_t>(5, 3);
49
createConfiguredMockDevice()50 SharedMockDevice createConfiguredMockDevice() {
51 auto mockDevice = std::make_shared<const nn::MockDevice>();
52
53 // Setup default actions for each relevant call.
54 constexpr auto getName_ret = []() -> const std::string& { return kName; };
55 constexpr auto getVersionString_ret = []() -> const std::string& { return kVersionString; };
56 constexpr auto kFeatureLevel = nn::kVersionFeatureLevel1;
57 constexpr auto kDeviceType = nn::DeviceType::ACCELERATOR;
58 constexpr auto getSupportedExtensions_ret = []() -> const std::vector<nn::Extension>& {
59 return kExtensions;
60 };
61 constexpr auto getCapabilities_ret = []() -> const nn::Capabilities& { return kCapabilities; };
62
63 // Setup default actions for each relevant call.
64 ON_CALL(*mockDevice, getName()).WillByDefault(getName_ret);
65 ON_CALL(*mockDevice, getVersionString()).WillByDefault(getVersionString_ret);
66 ON_CALL(*mockDevice, getFeatureLevel()).WillByDefault(Return(kFeatureLevel));
67 ON_CALL(*mockDevice, getType()).WillByDefault(Return(kDeviceType));
68 ON_CALL(*mockDevice, getSupportedExtensions()).WillByDefault(getSupportedExtensions_ret);
69 ON_CALL(*mockDevice, getCapabilities()).WillByDefault(getCapabilities_ret);
70 ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded())
71 .WillByDefault(Return(kNumberOfCacheFilesNeeded));
72
73 // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
74 // uninteresting methods calls.
75 EXPECT_CALL(*mockDevice, getName()).Times(testing::AnyNumber());
76 EXPECT_CALL(*mockDevice, getVersionString()).Times(testing::AnyNumber());
77 EXPECT_CALL(*mockDevice, getFeatureLevel()).Times(testing::AnyNumber());
78 EXPECT_CALL(*mockDevice, getType()).Times(testing::AnyNumber());
79 EXPECT_CALL(*mockDevice, getSupportedExtensions()).Times(testing::AnyNumber());
80 EXPECT_CALL(*mockDevice, getCapabilities()).Times(testing::AnyNumber());
81 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded()).Times(testing::AnyNumber());
82
83 return mockDevice;
84 }
85
86 std::tuple<SharedMockDevice, std::unique_ptr<MockDeviceFactory>,
87 std::shared_ptr<const ResilientDevice>>
setup()88 setup() {
89 auto mockDevice = createConfiguredMockDevice();
90
91 auto mockDeviceFactory = std::make_unique<MockDeviceFactory>();
92 EXPECT_CALL(*mockDeviceFactory, Call(true)).Times(1).WillOnce(Return(mockDevice));
93
94 auto device = ResilientDevice::create(mockDeviceFactory->AsStdFunction()).value();
95 return std::make_tuple(std::move(mockDevice), std::move(mockDeviceFactory), std::move(device));
96 }
97
__anon636ef9190602(nn::ErrorStatus status) 98 constexpr auto makeError = [](nn::ErrorStatus status) {
99 return [status](const auto&... /*args*/) { return nn::error(status); };
100 };
101 const auto kReturnGeneralFailure = makeError(nn::ErrorStatus::GENERAL_FAILURE);
102 const auto kReturnDeadObject = makeError(nn::ErrorStatus::DEAD_OBJECT);
103
104 } // namespace
105
TEST(ResilientDeviceTest, invalidDeviceFactory) {
    // A value-initialized (empty) factory must be rejected with INVALID_ARGUMENT.
    const ResilientDevice::Factory emptyFactory{};

    const auto created = ResilientDevice::create(emptyFactory);

    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
}
117
// NOTE: despite its name, this test covers a failing *device* factory. The name is kept so
// existing test filters keep matching. ResilientDevice::create must report the factory's
// GENERAL_FAILURE.
TEST(ResilientDeviceTest, preparedModelFactoryFailure) {
    const auto failingFactory = kReturnGeneralFailure;

    const auto created = ResilientDevice::create(failingFactory);

    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
129
TEST(ResilientDeviceTest, cachedData) {
    // The metadata queries must return the values the configured mock device reports.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();

    EXPECT_EQ(resilientDevice->getName(), kName);
    EXPECT_EQ(resilientDevice->getVersionString(), kVersionString);
    EXPECT_EQ(resilientDevice->getSupportedExtensions(), kExtensions);
    EXPECT_EQ(resilientDevice->getCapabilities(), kCapabilities);
}
140
TEST(ResilientDeviceTest, getFeatureLevel) {
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    constexpr auto kExpectedFeatureLevel = nn::kVersionFeatureLevel1;
    // The query must reach the underlying device exactly once.
    EXPECT_CALL(*initialDevice, getFeatureLevel())
            .Times(1)
            .WillOnce(Return(kExpectedFeatureLevel));

    EXPECT_EQ(resilientDevice->getFeatureLevel(), kExpectedFeatureLevel);
}
153
TEST(ResilientDeviceTest, getType) {
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    constexpr auto kExpectedType = nn::DeviceType::ACCELERATOR;
    // The query must reach the underlying device exactly once.
    EXPECT_CALL(*initialDevice, getType()).Times(1).WillOnce(Return(kExpectedType));

    EXPECT_EQ(resilientDevice->getType(), kExpectedType);
}
166
TEST(ResilientDeviceTest, getNumberOfCacheFilesNeeded) {
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    // The query must reach the underlying device exactly once.
    EXPECT_CALL(*initialDevice, getNumberOfCacheFilesNeeded())
            .Times(1)
            .WillOnce(Return(kNumberOfCacheFilesNeeded));

    EXPECT_EQ(resilientDevice->getNumberOfCacheFilesNeeded(), kNumberOfCacheFilesNeeded);
}
180
TEST(ResilientDeviceTest, getDevice) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();

    // run test
    const auto result = device->getDevice();

    // verify result: the wrapper exposes exactly the device the factory produced.
    // EXPECT_EQ (unlike EXPECT_TRUE(a == b)) prints both pointers when it fails.
    EXPECT_EQ(result, mockDevice);
}
191
TEST(ResilientDeviceTest, wait) {
    // Arrange: the underlying device completes wait() successfully.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));

    // Act
    const auto waitResult = resilientDevice->wait();

    // Assert: the success is passed through.
    ASSERT_TRUE(waitResult.has_value())
            << "Failed with " << waitResult.error().code << ": " << waitResult.error().message;
}
204
TEST(ResilientDeviceTest, waitError) {
    // Arrange: wait() fails with GENERAL_FAILURE. The factory has no further expectations, so
    // any recovery attempt would fail this test.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, wait()).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto waitResult = resilientDevice->wait();

    // Assert: the error reaches the caller unchanged.
    ASSERT_FALSE(waitResult.has_value());
    EXPECT_EQ(waitResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
217
TEST(ResilientDeviceTest, waitDeadObjectFailedRecovery) {
    // Arrange: wait() reports DEAD_OBJECT, and rebuilding the device (factory called with
    // `true`) fails.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, wait()).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*deviceFactory, Call(true)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto waitResult = resilientDevice->wait();

    // Assert: the caller sees the original DEAD_OBJECT, not the factory's failure.
    ASSERT_FALSE(waitResult.has_value());
    EXPECT_EQ(waitResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
231
TEST(ResilientDeviceTest, waitDeadObjectSuccessfulRecovery) {
    // Arrange: wait() reports DEAD_OBJECT on the first device. The factory (called with `true`)
    // then yields a replacement whose wait() succeeds.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, wait()).Times(1).WillOnce(kReturnDeadObject);
    const auto replacementDevice = createConfiguredMockDevice();
    EXPECT_CALL(*replacementDevice, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));
    EXPECT_CALL(*deviceFactory, Call(true)).Times(1).WillOnce(Return(replacementDevice));

    // Act
    const auto waitResult = resilientDevice->wait();

    // Assert: the call is retried on the replacement and succeeds.
    ASSERT_TRUE(waitResult.has_value())
            << "Failed with " << waitResult.error().code << ": " << waitResult.error().message;
}
247
TEST(ResilientDeviceTest, getSupportedOperations) {
    // Arrange: the underlying device answers with a default-constructed (successful) result.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));

    // Act
    const auto supported = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(supported.has_value())
            << "Failed with " << supported.error().code << ": " << supported.error().message;
}
262
TEST(ResilientDeviceTest, getSupportedOperationsError) {
    // Arrange: the query fails with GENERAL_FAILURE, and no recovery is expected.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, getSupportedOperations(_))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto supported = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_FALSE(supported.has_value());
    EXPECT_EQ(supported.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
275
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectFailedRecovery) {
    // Arrange: the query reports DEAD_OBJECT, and the non-blocking (`false`) rebuild fails.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto supported = resilientDevice->getSupportedOperations({});

    // Assert: the original DEAD_OBJECT is reported.
    ASSERT_FALSE(supported.has_value());
    EXPECT_EQ(supported.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
289
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectSuccessfulRecovery) {
    // Arrange: DEAD_OBJECT on the first device, then a non-blocking (`false`) rebuild returns a
    // replacement that answers the query successfully.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    const auto replacementDevice = createConfiguredMockDevice();
    EXPECT_CALL(*replacementDevice, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(Return(replacementDevice));

    // Act
    const auto supported = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(supported.has_value())
            << "Failed with " << supported.error().code << ": " << supported.error().message;
}
307
TEST(ResilientDeviceTest, prepareModel) {
    // Arrange: the underlying device returns a prepared model on the first attempt.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    const auto fakePreparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*initialDevice, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(fakePreparedModel));

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
323
TEST(ResilientDeviceTest, prepareModelError) {
    // Arrange: preparation fails with GENERAL_FAILURE, and no recovery is expected.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
338
TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
    // Arrange: preparation reports DEAD_OBJECT, and the non-blocking (`false`) rebuild fails.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert: the original DEAD_OBJECT is reported.
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
354
TEST(ResilientDeviceTest, prepareModelDeadObjectSuccessfulRecovery) {
    // Arrange: DEAD_OBJECT on the first device, then a non-blocking (`false`) rebuild returns a
    // replacement that prepares the model successfully.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    const auto replacementDevice = createConfiguredMockDevice();
    const auto fakePreparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*replacementDevice, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(fakePreparedModel));
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(Return(replacementDevice));

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
375
TEST(ResilientDeviceTest, prepareModelFromCache) {
    // Arrange: the underlying device returns a prepared model from cache on the first attempt.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    const auto fakePreparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*initialDevice, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(fakePreparedModel));

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
391
TEST(ResilientDeviceTest, prepareModelFromCacheError) {
    // Arrange: preparation from cache fails with GENERAL_FAILURE, and no recovery is expected.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
406
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectFailedRecovery) {
    // Arrange: preparation from cache reports DEAD_OBJECT, and the non-blocking (`false`)
    // rebuild fails.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert: the original DEAD_OBJECT is reported.
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
422
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectSuccessfulRecovery) {
    // Arrange: DEAD_OBJECT on the first device, then a non-blocking (`false`) rebuild returns a
    // replacement that prepares from cache successfully.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    const auto replacementDevice = createConfiguredMockDevice();
    const auto fakePreparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*replacementDevice, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(fakePreparedModel));
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(Return(replacementDevice));

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
443
TEST(ResilientDeviceTest, allocate) {
    // Arrange: the underlying device hands back a buffer on the first attempt.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    const auto fakeBuffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*initialDevice, allocate(_, _, _, _)).Times(1).WillOnce(Return(fakeBuffer));

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
457
TEST(ResilientDeviceTest, allocateError) {
    // Arrange: allocation fails with GENERAL_FAILURE, and no recovery is expected.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, allocate(_, _, _, _)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
470
TEST(ResilientDeviceTest, allocateDeadObjectFailedRecovery) {
    // Arrange: allocation reports DEAD_OBJECT, and the non-blocking (`false`) rebuild fails.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert: the original DEAD_OBJECT is reported.
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
484
TEST(ResilientDeviceTest, allocateDeadObjectSuccessfulRecovery) {
    // Arrange: DEAD_OBJECT on the first device, then a non-blocking (`false`) rebuild returns a
    // replacement that allocates successfully.
    const auto [initialDevice, deviceFactory, resilientDevice] = setup();
    EXPECT_CALL(*initialDevice, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    const auto replacementDevice = createConfiguredMockDevice();
    const auto fakeBuffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*replacementDevice, allocate(_, _, _, _)).Times(1).WillOnce(Return(fakeBuffer));
    EXPECT_CALL(*deviceFactory, Call(false)).Times(1).WillOnce(Return(replacementDevice));

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
501
TEST(ResilientDeviceTest, recover) {
    // setup call: a non-blocking recovery asks the factory for a new device with `false` and
    // should adopt whatever it returns.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result (EXPECT_EQ prints both pointers on mismatch, unlike EXPECT_TRUE(a == b))
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_EQ(result.value(), recoveredMockDevice);
}
516
TEST(ResilientDeviceTest, recoverFailure) {
    // setup call: the factory fails to produce a replacement device. Matching on `false`
    // (rather than `_`), as the other non-blocking recovery tests do, also pins that the
    // non-blocking flag is forwarded to the factory.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    EXPECT_FALSE(result.has_value());
}
529
TEST(ResilientDeviceTest, someoneElseRecovered) {
    // setup call: a first recovery replaces mockDevice with recoveredMockDevice.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // The first recovery is a precondition. Check it instead of discarding its result, otherwise
    // a failure here would surface as a confusing failure below.
    const auto firstRecovery = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(firstRecovery.has_value())
            << "Failed with " << firstRecovery.error().code << ": "
            << firstRecovery.error().message;

    // run test: a second caller still holding the stale device must receive the already
    // recovered device. Times(1) above makes a second factory call fail the test.
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_EQ(result.value(), recoveredMockDevice);
}
545
TEST(ResilientDeviceTest, recoverCacheMismatchGetName) {
    // setup call: the recovered device reports a name different from the configured one.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    const std::string kDifferentName = "Google-DifferentName";
    // Captured by reference so the returned reference points at this test-scoped string.
    const auto ret = [&kDifferentName]() -> const std::string& { return kDifferentName; };
    EXPECT_CALL(*recoveredMockDevice, getName()).Times(1).WillOnce(ret);
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
565
TEST(ResilientDeviceTest, recoverCacheMismatchGetVersionString) {
    // setup call: the recovered device reports a different version string.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    const std::string kDifferentVersionString = "differentversion";
    // Captured by reference so the returned reference points at this test-scoped string.
    const auto ret = [&kDifferentVersionString]() -> const std::string& {
        return kDifferentVersionString;
    };
    EXPECT_CALL(*recoveredMockDevice, getVersionString()).Times(1).WillOnce(ret);
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
587
TEST(ResilientDeviceTest, recoverCacheMismatchGetFeatureLevel) {
    // setup call: the recovered device reports feature level 2 instead of the configured 1.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getFeatureLevel())
            .Times(1)
            .WillOnce(Return(nn::kVersionFeatureLevel2));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
607
TEST(ResilientDeviceTest, recoverCacheMismatchGetType) {
    // setup call: the recovered device reports GPU instead of the configured ACCELERATOR.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
625
TEST(ResilientDeviceTest, recoverCacheMismatchGetSupportedExtensions) {
    // setup call: the recovered device reports one extension where the configured device
    // reports none.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    const auto kDifferentExtensions =
            std::vector<nn::Extension>{nn::Extension{.name = "", .operandTypes = {}}};
    // Captured by reference so the returned reference points at this test-scoped vector.
    const auto ret = [&kDifferentExtensions]() -> const std::vector<nn::Extension>& {
        return kDifferentExtensions;
    };
    EXPECT_CALL(*recoveredMockDevice, getSupportedExtensions()).Times(1).WillOnce(ret);
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
648
TEST(ResilientDeviceTest, recoverCacheMismatchGetCapabilities) {
    // setup call: the recovered device reports capabilities that differ from kCapabilities.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    const auto kDifferentCapabilities = nn::Capabilities{
            .relaxedFloat32toFloat16PerformanceTensor = {.execTime = 0.5f, .powerUsage = 0.5f},
            .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value()};
    // Captured by reference so the returned reference points at this test-scoped object.
    const auto ret = [&kDifferentCapabilities]() -> const nn::Capabilities& {
        return kDifferentCapabilities;
    };
    EXPECT_CALL(*recoveredMockDevice, getCapabilities()).Times(1).WillOnce(ret);
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result: recovery succeeds, but the returned device is non-null and is neither the
    // original nor the mismatched recovered device.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
    EXPECT_NE(result.value(), mockDevice);
    EXPECT_NE(result.value(), recoveredMockDevice);
}
672
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModel) {
    // setup call: recover to a device whose type (GPU) differs from the configured ACCELERATOR.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // The mismatched recovery is a precondition of this test; don't silently ignore its result.
    ASSERT_TRUE(device->recover(mockDevice.get(), /*blocking=*/false).has_value());

    // run test
    auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // verify result: prepareModel still returns a non-null prepared model.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
}
689
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModelFromCache) {
    // setup call: recover to a device whose type (GPU) differs from the configured ACCELERATOR.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // The mismatched recovery is a precondition of this test; don't silently ignore its result.
    ASSERT_TRUE(device->recover(mockDevice.get(), /*blocking=*/false).has_value());

    // run test
    auto result = device->prepareModelFromCache({}, {}, {}, {});

    // verify result: prepareModelFromCache still returns a non-null prepared model.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
}
706
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidAllocate) {
    // setup call: recover to a device whose type (GPU) differs from the configured ACCELERATOR.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // The mismatched recovery is a precondition of this test; don't silently ignore its result.
    ASSERT_TRUE(device->recover(mockDevice.get(), /*blocking=*/false).has_value());

    // run test
    auto result = device->allocate({}, {}, {}, {});

    // verify result: allocate still returns a non-null buffer.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_NE(result.value(), nullptr);
}
723
724 } // namespace android::hardware::neuralnetworks::utils
725