1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <nnapi/TypeUtils.h>
19 #include <nnapi/Types.h>
20 #include <nnapi/hal/ResilientDevice.h>
21 #include <tuple>
22 #include <utility>
23 #include "MockBuffer.h"
24 #include "MockDevice.h"
25 #include "MockPreparedModel.h"
26 
27 namespace android::hardware::neuralnetworks::utils {
28 namespace {
29 
30 using ::testing::_;
31 using ::testing::Return;
32 
33 using SharedMockDevice = std::shared_ptr<const nn::MockDevice>;
34 using MockDeviceFactory = ::testing::MockFunction<nn::GeneralResult<nn::SharedDevice>(bool)>;
35 
36 const std::string kName = "Google-MockV1";
37 const std::string kVersionString = "version1";
38 const auto kExtensions = std::vector<nn::Extension>{};
39 constexpr auto kNoInfo = std::numeric_limits<float>::max();
40 constexpr auto kNoPerformanceInfo =
41         nn::Capabilities::PerformanceInfo{.execTime = kNoInfo, .powerUsage = kNoInfo};
42 const auto kCapabilities = nn::Capabilities{
43         .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
44         .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
45         .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value(),
46         .ifPerformance = kNoPerformanceInfo,
47         .whilePerformance = kNoPerformanceInfo};
48 constexpr auto kNumberOfCacheFilesNeeded = std::pair<uint32_t, uint32_t>(5, 3);
49 
createConfiguredMockDevice()50 SharedMockDevice createConfiguredMockDevice() {
51     auto mockDevice = std::make_shared<const nn::MockDevice>();
52 
53     // Setup default actions for each relevant call.
54     constexpr auto getName_ret = []() -> const std::string& { return kName; };
55     constexpr auto getVersionString_ret = []() -> const std::string& { return kVersionString; };
56     constexpr auto kFeatureLevel = nn::kVersionFeatureLevel1;
57     constexpr auto kDeviceType = nn::DeviceType::ACCELERATOR;
58     constexpr auto getSupportedExtensions_ret = []() -> const std::vector<nn::Extension>& {
59         return kExtensions;
60     };
61     constexpr auto getCapabilities_ret = []() -> const nn::Capabilities& { return kCapabilities; };
62 
63     // Setup default actions for each relevant call.
64     ON_CALL(*mockDevice, getName()).WillByDefault(getName_ret);
65     ON_CALL(*mockDevice, getVersionString()).WillByDefault(getVersionString_ret);
66     ON_CALL(*mockDevice, getFeatureLevel()).WillByDefault(Return(kFeatureLevel));
67     ON_CALL(*mockDevice, getType()).WillByDefault(Return(kDeviceType));
68     ON_CALL(*mockDevice, getSupportedExtensions()).WillByDefault(getSupportedExtensions_ret);
69     ON_CALL(*mockDevice, getCapabilities()).WillByDefault(getCapabilities_ret);
70     ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded())
71             .WillByDefault(Return(kNumberOfCacheFilesNeeded));
72 
73     // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
74     // uninteresting methods calls.
75     EXPECT_CALL(*mockDevice, getName()).Times(testing::AnyNumber());
76     EXPECT_CALL(*mockDevice, getVersionString()).Times(testing::AnyNumber());
77     EXPECT_CALL(*mockDevice, getFeatureLevel()).Times(testing::AnyNumber());
78     EXPECT_CALL(*mockDevice, getType()).Times(testing::AnyNumber());
79     EXPECT_CALL(*mockDevice, getSupportedExtensions()).Times(testing::AnyNumber());
80     EXPECT_CALL(*mockDevice, getCapabilities()).Times(testing::AnyNumber());
81     EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded()).Times(testing::AnyNumber());
82 
83     return mockDevice;
84 }
85 
86 std::tuple<SharedMockDevice, std::unique_ptr<MockDeviceFactory>,
87            std::shared_ptr<const ResilientDevice>>
setup()88 setup() {
89     auto mockDevice = createConfiguredMockDevice();
90 
91     auto mockDeviceFactory = std::make_unique<MockDeviceFactory>();
92     EXPECT_CALL(*mockDeviceFactory, Call(true)).Times(1).WillOnce(Return(mockDevice));
93 
94     auto device = ResilientDevice::create(mockDeviceFactory->AsStdFunction()).value();
95     return std::make_tuple(std::move(mockDevice), std::move(mockDeviceFactory), std::move(device));
96 }
97 
__anon636ef9190602(nn::ErrorStatus status) 98 constexpr auto makeError = [](nn::ErrorStatus status) {
99     return [status](const auto&... /*args*/) { return nn::error(status); };
100 };
101 const auto kReturnGeneralFailure = makeError(nn::ErrorStatus::GENERAL_FAILURE);
102 const auto kReturnDeadObject = makeError(nn::ErrorStatus::DEAD_OBJECT);
103 
104 }  // namespace
105 
// Creating a ResilientDevice from a value-initialized factory must fail with INVALID_ARGUMENT.
TEST(ResilientDeviceTest, invalidDeviceFactory) {
    // Arrange: a value-initialized factory.
    const ResilientDevice::Factory emptyFactory{};

    // Act
    const auto created = ResilientDevice::create(emptyFactory);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
}
117 
// When the device factory itself fails, create() must report that failure's error code.
TEST(ResilientDeviceTest, preparedModelFactoryFailure) {
    // Arrange: a factory that always reports GENERAL_FAILURE.
    const auto failingDeviceFactory = kReturnGeneralFailure;

    // Act
    const auto created = ResilientDevice::create(failingDeviceFactory);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
129 
// The metadata exposed by the resilient device must equal what the underlying mock reports.
TEST(ResilientDeviceTest, cachedData) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();

    // Act + Assert: each getter matches the corresponding file-level constant.
    EXPECT_EQ(resilient->getName(), kName);
    EXPECT_EQ(resilient->getVersionString(), kVersionString);
    EXPECT_EQ(resilient->getSupportedExtensions(), kExtensions);
    EXPECT_EQ(resilient->getCapabilities(), kCapabilities);
}
140 
// getFeatureLevel() must query the underlying device exactly once and return its answer.
TEST(ResilientDeviceTest, getFeatureLevel) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    constexpr auto kExpectedLevel = nn::kVersionFeatureLevel1;
    EXPECT_CALL(*underlying, getFeatureLevel()).Times(1).WillOnce(Return(kExpectedLevel));

    // Act
    const auto reportedLevel = resilient->getFeatureLevel();

    // Assert
    EXPECT_EQ(reportedLevel, kExpectedLevel);
}
153 
// getType() must query the underlying device exactly once and return its answer.
TEST(ResilientDeviceTest, getType) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    constexpr auto kExpectedType = nn::DeviceType::ACCELERATOR;
    EXPECT_CALL(*underlying, getType()).Times(1).WillOnce(Return(kExpectedType));

    // Act
    const auto reportedType = resilient->getType();

    // Assert
    EXPECT_EQ(reportedType, kExpectedType);
}
166 
// getNumberOfCacheFilesNeeded() must query the underlying device exactly once and return its
// answer.
TEST(ResilientDeviceTest, getNumberOfCacheFilesNeeded) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, getNumberOfCacheFilesNeeded())
            .Times(1)
            .WillOnce(Return(kNumberOfCacheFilesNeeded));

    // Act
    const auto reportedCounts = resilient->getNumberOfCacheFilesNeeded();

    // Assert
    EXPECT_EQ(reportedCounts, kNumberOfCacheFilesNeeded);
}
180 
// getDevice() must expose the very device object that the factory produced.
TEST(ResilientDeviceTest, getDevice) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();

    // Act
    const auto currentDevice = resilient->getDevice();

    // Assert: same object (pointer identity), not merely an equivalent device.
    EXPECT_TRUE(currentDevice == underlying);
}
191 
// wait() succeeds when the underlying device's wait() succeeds.
TEST(ResilientDeviceTest, wait) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));

    // Act
    const auto outcome = resilient->wait();

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
204 
// A GENERAL_FAILURE from the underlying wait() is passed through to the caller.
TEST(ResilientDeviceTest, waitError) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, wait()).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->wait();

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
217 
// If wait() hits DEAD_OBJECT and the factory cannot produce a replacement device, the caller
// sees the original DEAD_OBJECT error rather than the factory's GENERAL_FAILURE.
TEST(ResilientDeviceTest, waitDeadObjectFailedRecovery) {
    // Arrange: the device dies, and the one recovery attempt (factory called with true) fails.
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, wait()).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(true)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->wait();

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
231 
// If wait() hits DEAD_OBJECT and the factory supplies a working replacement, wait() is retried
// on the replacement and the overall call succeeds.
TEST(ResilientDeviceTest, waitDeadObjectSuccessfulRecovery) {
    // Arrange: original device dies; factory (called with true) yields a healthy replacement.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*underlying, wait()).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(true)).Times(1).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));

    // Act
    const auto outcome = resilient->wait();

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
247 
// getSupportedOperations() succeeds when the underlying device's call succeeds.
TEST(ResilientDeviceTest, getSupportedOperations) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));

    // Act
    const auto outcome = resilient->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
262 
// A GENERAL_FAILURE from the underlying getSupportedOperations() is passed through.
TEST(ResilientDeviceTest, getSupportedOperationsError) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, getSupportedOperations(_)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->getSupportedOperations({});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
275 
// If getSupportedOperations() hits DEAD_OBJECT and the factory cannot produce a replacement,
// the caller sees the original DEAD_OBJECT error.
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectFailedRecovery) {
    // Arrange: the device dies, and the one recovery attempt (factory called with false) fails.
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->getSupportedOperations({});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
289 
// If getSupportedOperations() hits DEAD_OBJECT and the factory supplies a working replacement,
// the call is retried on the replacement and succeeds.
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectSuccessfulRecovery) {
    // Arrange: original device dies; factory (called with false) yields a healthy replacement.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*underlying, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));

    // Act
    const auto outcome = resilient->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
307 
// prepareModel() succeeds when the underlying device returns a prepared model.
TEST(ResilientDeviceTest, prepareModel) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*underlying, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilient->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
323 
// A GENERAL_FAILURE from the underlying prepareModel() is passed through.
TEST(ResilientDeviceTest, prepareModelError) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
338 
// If prepareModel() hits DEAD_OBJECT and the factory cannot produce a replacement, the caller
// sees the original DEAD_OBJECT error.
TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
    // Arrange: the device dies, and the one recovery attempt (factory called with false) fails.
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
354 
// If prepareModel() hits DEAD_OBJECT and the factory supplies a working replacement, the call
// is retried on the replacement and succeeds.
TEST(ResilientDeviceTest, prepareModelDeadObjectSuccessfulRecovery) {
    // Arrange: original device dies; factory (called with false) yields a healthy replacement.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*underlying, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, prepareModel(_, _, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilient->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
375 
// prepareModelFromCache() succeeds when the underlying device returns a prepared model.
TEST(ResilientDeviceTest, prepareModelFromCache) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*underlying, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilient->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
391 
// A GENERAL_FAILURE from the underlying prepareModelFromCache() is passed through.
TEST(ResilientDeviceTest, prepareModelFromCacheError) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
406 
// If prepareModelFromCache() hits DEAD_OBJECT and the factory cannot produce a replacement, the
// caller sees the original DEAD_OBJECT error.
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectFailedRecovery) {
    // Arrange: the device dies, and the one recovery attempt (factory called with false) fails.
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
422 
// If prepareModelFromCache() hits DEAD_OBJECT and the factory supplies a working replacement,
// the call is retried on the replacement and succeeds.
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectSuccessfulRecovery) {
    // Arrange: original device dies; factory (called with false) yields a healthy replacement.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*underlying, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilient->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
443 
// allocate() succeeds when the underlying device returns a buffer.
TEST(ResilientDeviceTest, allocate) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    const auto buffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*underlying, allocate(_, _, _, _)).Times(1).WillOnce(Return(buffer));

    // Act
    const auto outcome = resilient->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
457 
// A GENERAL_FAILURE from the underlying allocate() is passed through.
TEST(ResilientDeviceTest, allocateError) {
    // Arrange
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, allocate(_, _, _, _)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->allocate({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
470 
// If allocate() hits DEAD_OBJECT and the factory cannot produce a replacement, the caller sees
// the original DEAD_OBJECT error.
TEST(ResilientDeviceTest, allocateDeadObjectFailedRecovery) {
    // Arrange: the device dies, and the one recovery attempt (factory called with false) fails.
    const auto [underlying, factory, resilient] = setup();
    EXPECT_CALL(*underlying, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilient->allocate({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
484 
// If allocate() hits DEAD_OBJECT and the factory supplies a working replacement, the call is
// retried on the replacement and succeeds.
TEST(ResilientDeviceTest, allocateDeadObjectSuccessfulRecovery) {
    // Arrange: original device dies; factory (called with false) yields a healthy replacement.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto buffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*underlying, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, allocate(_, _, _, _)).Times(1).WillOnce(Return(buffer));

    // Act
    const auto outcome = resilient->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
501 
// recover() on the currently held device asks the factory for a new device and returns it.
TEST(ResilientDeviceTest, recover) {
    // Arrange: the factory (called with false) yields a replacement with identical metadata.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: the returned device is exactly the one the factory produced.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() == replacement);
}
516 
// recover() must report a failure when the device factory fails.
TEST(ResilientDeviceTest, recoverFailure) {
    // setup call
    // recover() is invoked with blocking=false below, so the factory is expected to be called
    // with `false`, matching the expectation used by the `recover` test for the same invocation.
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    EXPECT_FALSE(result.has_value());
}
529 
// If the device passed to recover() as "failing" is no longer the one currently held (because an
// earlier recover() already replaced it), recover() must hand back the current device without
// calling the factory again.
TEST(ResilientDeviceTest, someoneElseRecovered) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    // Times(1): only the first recover() below may reach the factory.
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // First recovery swaps in recoveredMockDevice. Check it instead of discarding the result so
    // that a broken precondition is reported here rather than as a confusing failure below.
    const auto firstRecovery = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(firstRecovery.has_value()) << "Failed with " << firstRecovery.error().code << ": "
                                           << firstRecovery.error().message;

    // run test: report the already-replaced device as failing a second time.
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() == recoveredMockDevice);
}
545 
// When the replacement device reports a different name than the original, recover() still
// succeeds but returns a device that is neither the original nor the replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetName) {
    // Arrange: the replacement's getName() answers with a name other than kName.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const std::string mismatchedName = "Google-DifferentName";
    const auto returnMismatchedName = [&mismatchedName]() -> const std::string& {
        return mismatchedName;
    };
    EXPECT_CALL(*replacement, getName()).Times(1).WillOnce(returnMismatchedName);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
565 
// When the replacement device reports a different version string than the original, recover()
// still succeeds but returns a device that is neither the original nor the replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetVersionString) {
    // Arrange: the replacement's getVersionString() answers with something other than
    // kVersionString.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const std::string mismatchedVersion = "differentversion";
    const auto returnMismatchedVersion = [&mismatchedVersion]() -> const std::string& {
        return mismatchedVersion;
    };
    EXPECT_CALL(*replacement, getVersionString()).Times(1).WillOnce(returnMismatchedVersion);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
587 
// When the replacement device reports a different feature level than the original, recover()
// still succeeds but returns a device that is neither the original nor the replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetFeatureLevel) {
    // Arrange: the replacement reports feature level 2 instead of the configured level 1.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, getFeatureLevel())
            .Times(1)
            .WillOnce(Return(nn::kVersionFeatureLevel2));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
607 
// When the replacement device reports a different device type than the original, recover()
// still succeeds but returns a device that is neither the original nor the replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetType) {
    // Arrange: the replacement reports GPU instead of the configured ACCELERATOR.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
625 
// When the replacement device reports different supported extensions than the original,
// recover() still succeeds but returns a device that is neither the original nor the
// replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetSupportedExtensions) {
    // Arrange: the replacement reports one extension, whereas kExtensions is empty.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto mismatchedExtensions =
            std::vector<nn::Extension>{nn::Extension{.name = "", .operandTypes = {}}};
    const auto returnMismatchedExtensions =
            [&mismatchedExtensions]() -> const std::vector<nn::Extension>& {
        return mismatchedExtensions;
    };
    EXPECT_CALL(*replacement, getSupportedExtensions())
            .Times(1)
            .WillOnce(returnMismatchedExtensions);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
648 
// When the replacement device reports different capabilities than the original, recover() still
// succeeds but returns a device that is neither the original nor the replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetCapabilities) {
    // Arrange: the replacement's capabilities differ from kCapabilities.
    const auto [underlying, factory, resilient] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto mismatchedCapabilities = nn::Capabilities{
            .relaxedFloat32toFloat16PerformanceTensor = {.execTime = 0.5f, .powerUsage = 0.5f},
            .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value()};
    const auto returnMismatchedCapabilities =
            [&mismatchedCapabilities]() -> const nn::Capabilities& {
        return mismatchedCapabilities;
    };
    EXPECT_CALL(*replacement, getCapabilities()).Times(1).WillOnce(returnMismatchedCapabilities);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilient->recover(underlying.get(), /*blocking=*/false);

    // Assert: some third, non-null device is returned.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != underlying);
    EXPECT_TRUE(outcome.value() != replacement);
}
672 
// After a recovery whose replacement device reports mismatching metadata (here: a different
// device type), prepareModel() on the resilient device must still yield a non-null result.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModel) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Check the precondition instead of discarding the recovery result, so that a failed
    // recovery is reported here rather than as a misleading prepareModel() failure.
    const auto recovery = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovery.has_value())
            << "Failed with " << recovery.error().code << ": " << recovery.error().message;

    // run test
    const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
689 
// After a recovery whose replacement device reports mismatching metadata (here: a different
// device type), prepareModelFromCache() on the resilient device must still yield a non-null
// result.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModelFromCache) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Check the precondition instead of discarding the recovery result, so that a failed
    // recovery is reported here rather than as a misleading prepareModelFromCache() failure.
    const auto recovery = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovery.has_value())
            << "Failed with " << recovery.error().code << ": " << recovery.error().message;

    // run test
    const auto result = device->prepareModelFromCache({}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
706 
// After a recovery whose replacement device reports mismatching metadata (here: a different
// device type), allocate() on the resilient device must still yield a non-null result.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidAllocate) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Check the precondition instead of discarding the recovery result, so that a failed
    // recovery is reported here rather than as a misleading allocate() failure.
    const auto recovery = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovery.has_value())
            << "Failed with " << recovery.error().code << ": " << recovery.error().message;

    // run test
    const auto result = device->allocate({}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
723 
724 }  // namespace android::hardware::neuralnetworks::utils
725