1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <gmock/gmock.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/ResilientDevice.h>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "MockBuffer.h"
#include "MockDevice.h"
#include "MockPreparedModel.h"
26 
27 namespace android::hardware::neuralnetworks::utils {
28 namespace {
29 
30 using ::testing::_;
31 using ::testing::InvokeWithoutArgs;
32 using ::testing::Return;
33 
34 using SharedMockDevice = std::shared_ptr<const nn::MockDevice>;
35 using MockDeviceFactory = ::testing::MockFunction<nn::GeneralResult<nn::SharedDevice>(bool)>;
36 
37 const std::string kName = "Google-MockV1";
38 const std::string kVersionString = "version1";
39 const auto kExtensions = std::vector<nn::Extension>{};
40 constexpr auto kNoInfo = std::numeric_limits<float>::max();
41 constexpr auto kNoPerformanceInfo =
42         nn::Capabilities::PerformanceInfo{.execTime = kNoInfo, .powerUsage = kNoInfo};
43 const auto kCapabilities = nn::Capabilities{
44         .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
45         .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
46         .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value(),
47         .ifPerformance = kNoPerformanceInfo,
48         .whilePerformance = kNoPerformanceInfo};
49 constexpr auto kNumberOfCacheFilesNeeded = std::pair<uint32_t, uint32_t>(5, 3);
50 
createConfiguredMockDevice()51 SharedMockDevice createConfiguredMockDevice() {
52     auto mockDevice = std::make_shared<const nn::MockDevice>();
53 
54     // Setup default actions for each relevant call.
55     constexpr auto getName_ret = []() -> const std::string& { return kName; };
56     constexpr auto getVersionString_ret = []() -> const std::string& { return kVersionString; };
57     constexpr auto kFeatureLevel = nn::Version::ANDROID_OC_MR1;
58     constexpr auto kDeviceType = nn::DeviceType::ACCELERATOR;
59     constexpr auto getSupportedExtensions_ret = []() -> const std::vector<nn::Extension>& {
60         return kExtensions;
61     };
62     constexpr auto getCapabilities_ret = []() -> const nn::Capabilities& { return kCapabilities; };
63 
64     // Setup default actions for each relevant call.
65     ON_CALL(*mockDevice, getName()).WillByDefault(getName_ret);
66     ON_CALL(*mockDevice, getVersionString()).WillByDefault(getVersionString_ret);
67     ON_CALL(*mockDevice, getFeatureLevel()).WillByDefault(Return(kFeatureLevel));
68     ON_CALL(*mockDevice, getType()).WillByDefault(Return(kDeviceType));
69     ON_CALL(*mockDevice, getSupportedExtensions()).WillByDefault(getSupportedExtensions_ret);
70     ON_CALL(*mockDevice, getCapabilities()).WillByDefault(getCapabilities_ret);
71     ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded())
72             .WillByDefault(Return(kNumberOfCacheFilesNeeded));
73 
74     // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
75     // uninteresting methods calls.
76     EXPECT_CALL(*mockDevice, getName()).Times(testing::AnyNumber());
77     EXPECT_CALL(*mockDevice, getVersionString()).Times(testing::AnyNumber());
78     EXPECT_CALL(*mockDevice, getFeatureLevel()).Times(testing::AnyNumber());
79     EXPECT_CALL(*mockDevice, getType()).Times(testing::AnyNumber());
80     EXPECT_CALL(*mockDevice, getSupportedExtensions()).Times(testing::AnyNumber());
81     EXPECT_CALL(*mockDevice, getCapabilities()).Times(testing::AnyNumber());
82     EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded()).Times(testing::AnyNumber());
83 
84     return mockDevice;
85 }
86 
87 std::tuple<SharedMockDevice, std::unique_ptr<MockDeviceFactory>,
88            std::shared_ptr<const ResilientDevice>>
setup()89 setup() {
90     auto mockDevice = createConfiguredMockDevice();
91 
92     auto mockDeviceFactory = std::make_unique<MockDeviceFactory>();
93     EXPECT_CALL(*mockDeviceFactory, Call(true)).Times(1).WillOnce(Return(mockDevice));
94 
95     auto device = ResilientDevice::create(mockDeviceFactory->AsStdFunction()).value();
96     return std::make_tuple(std::move(mockDevice), std::move(mockDeviceFactory), std::move(device));
97 }
98 
__anon636ef9190602(nn::ErrorStatus status) 99 constexpr auto makeError = [](nn::ErrorStatus status) {
100     return [status](const auto&... /*args*/) { return nn::error(status); };
101 };
102 const auto kReturnGeneralFailure = makeError(nn::ErrorStatus::GENERAL_FAILURE);
103 const auto kReturnDeadObject = makeError(nn::ErrorStatus::DEAD_OBJECT);
104 
105 }  // namespace
106 
// ResilientDevice::create rejects a default-constructed factory with INVALID_ARGUMENT.
TEST(ResilientDeviceTest, invalidDeviceFactory) {
    // Arrange
    const ResilientDevice::Factory emptyFactory{};

    // Act
    const auto outcome = ResilientDevice::create(emptyFactory);

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
}
118 
// Despite its name, this exercises the *device* factory: when the factory fails during
// ResilientDevice::create, the factory's error code is reported to the caller.
TEST(ResilientDeviceTest, preparedModelFactoryFailure) {
    // Arrange: a factory that always fails with GENERAL_FAILURE.
    const auto failingFactory = kReturnGeneralFailure;

    // Act
    const auto outcome = ResilientDevice::create(failingFactory);

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
130 
// The ResilientDevice reports the metadata the underlying mock device was configured with.
TEST(ResilientDeviceTest, cachedData) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();

    // Act & Assert
    EXPECT_EQ(resilientDevice->getName(), kName);
    EXPECT_EQ(resilientDevice->getVersionString(), kVersionString);
    EXPECT_EQ(resilientDevice->getSupportedExtensions(), kExtensions);
    EXPECT_EQ(resilientDevice->getCapabilities(), kCapabilities);
}
141 
// getFeatureLevel is delegated to the underlying device exactly once per call.
TEST(ResilientDeviceTest, getFeatureLevel) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    constexpr auto kExpectedFeatureLevel = nn::Version::ANDROID_OC_MR1;
    EXPECT_CALL(*mock, getFeatureLevel()).Times(1).WillOnce(Return(kExpectedFeatureLevel));

    // Act
    const auto actualFeatureLevel = resilientDevice->getFeatureLevel();

    // Assert
    EXPECT_EQ(actualFeatureLevel, kExpectedFeatureLevel);
}
154 
// getType is delegated to the underlying device exactly once per call.
TEST(ResilientDeviceTest, getType) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    constexpr auto kExpectedType = nn::DeviceType::ACCELERATOR;
    EXPECT_CALL(*mock, getType()).Times(1).WillOnce(Return(kExpectedType));

    // Act
    const auto actualType = resilientDevice->getType();

    // Assert
    EXPECT_EQ(actualType, kExpectedType);
}
167 
// getNumberOfCacheFilesNeeded is delegated to the underlying device exactly once per call.
TEST(ResilientDeviceTest, getNumberOfCacheFilesNeeded) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, getNumberOfCacheFilesNeeded())
            .Times(1)
            .WillOnce(Return(kNumberOfCacheFilesNeeded));

    // Act
    const auto cacheFileCounts = resilientDevice->getNumberOfCacheFilesNeeded();

    // Assert
    EXPECT_EQ(cacheFileCounts, kNumberOfCacheFilesNeeded);
}
181 
// Before any recovery, getDevice returns the device produced by the factory.
TEST(ResilientDeviceTest, getDevice) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();

    // Act
    const auto underlying = resilientDevice->getDevice();

    // Assert
    EXPECT_TRUE(underlying == mock);
}
192 
// A successful wait on the underlying device is reported as success.
TEST(ResilientDeviceTest, wait) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));

    // Act
    const auto outcome = resilientDevice->wait();

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
205 
// A non-DEAD_OBJECT error from wait is passed straight through; the factory is not called again.
TEST(ResilientDeviceTest, waitError) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, wait()).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->wait();

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
218 
// DEAD_OBJECT from wait makes the ResilientDevice invoke the factory with `true`; when that
// recovery fails, the original DEAD_OBJECT is reported.
TEST(ResilientDeviceTest, waitDeadObjectFailedRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, wait()).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(true)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->wait();

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
232 
// DEAD_OBJECT from wait triggers recovery (factory invoked with `true`); wait is then retried on
// the recovered device and its success is returned.
TEST(ResilientDeviceTest, waitDeadObjectSuccessfulRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, wait()).Times(1).WillOnce(kReturnDeadObject);
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, wait()).Times(1).WillOnce(Return(nn::GeneralResult<void>{}));
    EXPECT_CALL(*factory, Call(true)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->wait();

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
248 
// getSupportedOperations is delegated to the underlying device and its success is returned.
TEST(ResilientDeviceTest, getSupportedOperations) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));

    // Act
    const auto outcome = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
263 
// A non-DEAD_OBJECT error from getSupportedOperations is passed straight through.
TEST(ResilientDeviceTest, getSupportedOperationsError) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, getSupportedOperations(_)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
276 
// DEAD_OBJECT from getSupportedOperations triggers recovery (factory invoked with `false`); when
// recovery fails, the original DEAD_OBJECT is reported.
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectFailedRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
290 
// DEAD_OBJECT from getSupportedOperations triggers recovery (factory invoked with `false`); the
// call is then retried on the recovered device and its success is returned.
TEST(ResilientDeviceTest, getSupportedOperationsDeadObjectSuccessfulRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, getSupportedOperations(_)).Times(1).WillOnce(kReturnDeadObject);
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, getSupportedOperations(_))
            .Times(1)
            .WillOnce(Return(nn::GeneralResult<std::vector<bool>>{}));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->getSupportedOperations({});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
308 
// prepareModel is delegated to the underlying device and its prepared model is returned.
TEST(ResilientDeviceTest, prepareModel) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*mock, prepareModel(_, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
324 
// A non-DEAD_OBJECT error from prepareModel is passed straight through.
TEST(ResilientDeviceTest, prepareModelError) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModel(_, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
339 
// DEAD_OBJECT from prepareModel triggers recovery (factory invoked with `false`); when recovery
// fails, the original DEAD_OBJECT is reported.
TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModel(_, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
355 
// DEAD_OBJECT from prepareModel triggers recovery (factory invoked with `false`); the call is then
// retried on the recovered device and its prepared model is returned.
TEST(ResilientDeviceTest, prepareModelDeadObjectSuccessfulRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModel(_, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    const auto replacement = createConfiguredMockDevice();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*replacement, prepareModel(_, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->prepareModel({}, {}, {}, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
376 
// prepareModelFromCache is delegated to the underlying device and its prepared model is returned.
TEST(ResilientDeviceTest, prepareModelFromCache) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*mock, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
392 
// A non-DEAD_OBJECT error from prepareModelFromCache is passed straight through.
TEST(ResilientDeviceTest, prepareModelFromCacheError) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
407 
// DEAD_OBJECT from prepareModelFromCache triggers recovery (factory invoked with `false`); when
// recovery fails, the original DEAD_OBJECT is reported.
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectFailedRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
423 
// DEAD_OBJECT from prepareModelFromCache triggers recovery (factory invoked with `false`); the
// call is then retried on the recovered device and its prepared model is returned.
TEST(ResilientDeviceTest, prepareModelFromCacheDeadObjectSuccessfulRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(kReturnDeadObject);
    const auto replacement = createConfiguredMockDevice();
    const auto preparedModel = std::make_shared<const nn::MockPreparedModel>();
    EXPECT_CALL(*replacement, prepareModelFromCache(_, _, _, _))
            .Times(1)
            .WillOnce(Return(preparedModel));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
444 
// allocate is delegated to the underlying device and its buffer is returned.
TEST(ResilientDeviceTest, allocate) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto buffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*mock, allocate(_, _, _, _)).Times(1).WillOnce(Return(buffer));

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
458 
// A non-DEAD_OBJECT error from allocate is passed straight through.
TEST(ResilientDeviceTest, allocateError) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, allocate(_, _, _, _)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
471 
// DEAD_OBJECT from allocate triggers recovery (factory invoked with `false`); when recovery fails,
// the original DEAD_OBJECT is reported.
TEST(ResilientDeviceTest, allocateDeadObjectFailedRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
485 
// DEAD_OBJECT from allocate triggers recovery (factory invoked with `false`); the call is then
// retried on the recovered device and its buffer is returned.
TEST(ResilientDeviceTest, allocateDeadObjectSuccessfulRecovery) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    EXPECT_CALL(*mock, allocate(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
    const auto replacement = createConfiguredMockDevice();
    const auto buffer = std::make_shared<const nn::MockBuffer>();
    EXPECT_CALL(*replacement, allocate(_, _, _, _)).Times(1).WillOnce(Return(buffer));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->allocate({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
502 
// A non-blocking recover calls the factory with `false` and returns the device it produces.
TEST(ResilientDeviceTest, recover) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() == replacement);
}
517 
// recover reports an error when the device factory fails to produce a replacement device.
TEST(ResilientDeviceTest, recoverFailure) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    // A non-blocking recover invokes the factory with `false`; match that exactly, as the other
    // recover tests do, rather than accepting any argument.
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    EXPECT_FALSE(result.has_value());
}
530 
// Once the failing device has been replaced by an earlier recover call, a second recover for the
// same failing device returns that replacement without invoking the factory again (the factory
// expectation below allows exactly one call).
TEST(ResilientDeviceTest, someoneElseRecovered) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Perform the first recovery and check that it succeeded, so a failure here is reported at its
    // source instead of surfacing indirectly in the second call.
    const auto firstResult = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(firstResult.has_value())
            << "Failed with " << firstResult.error().code << ": " << firstResult.error().message;

    // run test
    const auto result = device->recover(mockDevice.get(), /*blocking=*/false);

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() == recoveredMockDevice);
}
546 
// If the recovered device reports a different name, recover still succeeds but returns a device
// that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetName) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    const std::string differentName = "Google-DifferentName";
    const auto returnDifferentName = [&differentName]() -> const std::string& {
        return differentName;
    };
    EXPECT_CALL(*replacement, getName()).Times(1).WillOnce(returnDifferentName);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
566 
// If the recovered device reports a different version string, recover still succeeds but returns
// a device that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetVersionString) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    const std::string differentVersionString = "differentversion";
    const auto returnDifferentVersionString = [&differentVersionString]() -> const std::string& {
        return differentVersionString;
    };
    EXPECT_CALL(*replacement, getVersionString()).Times(1).WillOnce(returnDifferentVersionString);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
588 
// If the recovered device reports a different feature level, recover still succeeds but returns
// a device that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetFeatureLevel) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, getFeatureLevel())
            .Times(1)
            .WillOnce(Return(nn::Version::ANDROID_P));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
608 
// If the recovered device reports a different device type, recover still succeeds but returns a
// device that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetType) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    EXPECT_CALL(*replacement, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
626 
// If the recovered device reports different extensions, recover still succeeds but returns a
// device that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetSupportedExtensions) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    const std::vector<nn::Extension> differentExtensions{
            nn::Extension{.name = "", .operandTypes = {}}};
    const auto returnDifferentExtensions =
            [&differentExtensions]() -> const std::vector<nn::Extension>& {
        return differentExtensions;
    };
    EXPECT_CALL(*replacement, getSupportedExtensions())
            .Times(1)
            .WillOnce(returnDifferentExtensions);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
649 
// If the recovered device reports different capabilities, recover still succeeds but returns a
// device that is neither the original nor the mismatched replacement.
TEST(ResilientDeviceTest, recoverCacheMismatchGetCapabilities) {
    // Arrange
    const auto [mock, factory, resilientDevice] = setup();
    const auto replacement = createConfiguredMockDevice();
    const auto differentCapabilities = nn::Capabilities{
            .relaxedFloat32toFloat16PerformanceTensor = {.execTime = 0.5f, .powerUsage = 0.5f},
            .operandPerformance = nn::Capabilities::OperandPerformanceTable::create({}).value()};
    const auto returnDifferentCapabilities = [&differentCapabilities]() -> const nn::Capabilities& {
        return differentCapabilities;
    };
    EXPECT_CALL(*replacement, getCapabilities()).Times(1).WillOnce(returnDifferentCapabilities);
    EXPECT_CALL(*factory, Call(false)).Times(1).WillOnce(Return(replacement));

    // Act
    const auto outcome = resilientDevice->recover(mock.get(), /*blocking=*/false);

    // Assert
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_TRUE(outcome.value() != nullptr);
    EXPECT_TRUE(outcome.value() != mock);
    EXPECT_TRUE(outcome.value() != replacement);
}
673 
// After a recovery whose device metadata mismatched, prepareModel still succeeds and returns a
// non-null prepared model.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModel) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Trigger the mismatched recovery and check that it succeeded, so a failure is reported here
    // rather than surfacing indirectly in the prepareModel checks below.
    const auto recovered = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovered.has_value())
            << "Failed with " << recovered.error().code << ": " << recovered.error().message;

    // run test
    auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
690 
// After a recovery whose device metadata mismatched, prepareModelFromCache still succeeds and
// returns a non-null prepared model.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModelFromCache) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Trigger the mismatched recovery and check that it succeeded, so a failure is reported here
    // rather than surfacing indirectly in the prepareModelFromCache checks below.
    const auto recovered = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovered.has_value())
            << "Failed with " << recovered.error().code << ": " << recovered.error().message;

    // run test
    auto result = device->prepareModelFromCache({}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
707 
// After a recovery whose device metadata mismatched, allocate still succeeds and returns a
// non-null buffer.
TEST(ResilientDeviceTest, recoverCacheMismatchInvalidAllocate) {
    // setup call
    const auto [mockDevice, mockDeviceFactory, device] = setup();
    const auto recoveredMockDevice = createConfiguredMockDevice();
    EXPECT_CALL(*recoveredMockDevice, getType()).Times(1).WillOnce(Return(nn::DeviceType::GPU));
    EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
    // Trigger the mismatched recovery and check that it succeeded, so a failure is reported here
    // rather than surfacing indirectly in the allocate checks below.
    const auto recovered = device->recover(mockDevice.get(), /*blocking=*/false);
    ASSERT_TRUE(recovered.has_value())
            << "Failed with " << recovered.error().code << ": " << recovered.error().message;

    // run test
    auto result = device->allocate({}, {}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() != nullptr);
}
724 
725 }  // namespace android::hardware::neuralnetworks::utils
726