1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <nnapi/TypeUtils.h>
19 #include <nnapi/Types.h>
20 #include <nnapi/hal/ResilientExecution.h>
21 #include <utility>
22 #include "MockExecution.h"
23 
24 namespace android::hardware::neuralnetworks::utils {
25 namespace {
26 
27 using ::testing::_;
28 using ::testing::InvokeWithoutArgs;
29 using ::testing::Return;
30 
31 using SharedMockExecution = std::shared_ptr<const nn::MockExecution>;
32 using MockExecutionFactory = ::testing::MockFunction<nn::GeneralResult<nn::SharedExecution>()>;
33 
createMockExecution()34 SharedMockExecution createMockExecution() {
35     return std::make_shared<const nn::MockExecution>();
36 }
37 
38 std::tuple<SharedMockExecution, std::unique_ptr<MockExecutionFactory>,
39            std::shared_ptr<const ResilientExecution>>
setup()40 setup() {
41     auto mockExecution = std::make_shared<const nn::MockExecution>();
42 
43     auto mockExecutionFactory = std::make_unique<MockExecutionFactory>();
44     EXPECT_CALL(*mockExecutionFactory, Call()).Times(1).WillOnce(Return(mockExecution));
45 
46     auto buffer = ResilientExecution::create(mockExecutionFactory->AsStdFunction()).value();
47     return std::make_tuple(std::move(mockExecution), std::move(mockExecutionFactory),
48                            std::move(buffer));
49 }
50 
__anon1bab929d0202(nn::ErrorStatus status) 51 constexpr auto makeError = [](nn::ErrorStatus status) {
52     return [status](const auto&... /*args*/) { return nn::error(status); };
53 };
54 const auto kReturnGeneralFailure = makeError(nn::ErrorStatus::GENERAL_FAILURE);
55 const auto kReturnDeadObject = makeError(nn::ErrorStatus::DEAD_OBJECT);
56 
57 const auto kNoExecutionError =
58         nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>{};
59 const auto kNoFencedExecutionError =
60         nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>(
61                 std::make_pair(nn::SyncFence::createAsSignaled(), nullptr));
62 
63 }  // namespace
64 
// A value-initialized Factory is not a usable factory; create() must reject it with
// INVALID_ARGUMENT.
TEST(ResilientExecutionTest, invalidExecutionFactory) {
    const ResilientExecution::Factory emptyFactory{};

    const auto created = ResilientExecution::create(emptyFactory);

    ASSERT_FALSE(created.has_value());
    const auto& failure = created.error();
    EXPECT_EQ(failure.code, nn::ErrorStatus::INVALID_ARGUMENT);
}
76 
// When the factory itself fails during create(), the factory's error code is reported.
TEST(ResilientExecutionTest, executionFactoryFailure) {
    const auto created = ResilientExecution::create(kReturnGeneralFailure);

    ASSERT_FALSE(created.has_value());
    const auto& failure = created.error();
    EXPECT_EQ(failure.code, nn::ErrorStatus::GENERAL_FAILURE);
}
88 
// getExecution() returns the execution the factory produced during create().
TEST(ResilientExecutionTest, getExecution) {
    const auto [mock, factory, resilientExecution] = setup();

    const auto current = resilientExecution->getExecution();

    EXPECT_TRUE(mock == current);
}
99 
// A successful compute() on the underlying mock execution yields a successful result.
// (A single WillOnce() without Times() implies exactly one call.)
TEST(ResilientExecutionTest, compute) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, compute(_)).WillOnce(Return(kNoExecutionError));

    const auto outcome = resilientExecution->compute({});

    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
112 
// A GENERAL_FAILURE from compute() reaches the caller with the same code. setup() allows exactly
// one factory call (consumed by create()), so any recovery attempt here would also fail the test.
TEST(ResilientExecutionTest, computeError) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, compute(_)).WillOnce(kReturnGeneralFailure);

    const auto outcome = resilientExecution->compute({});

    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
125 
// DEAD_OBJECT from compute() causes one factory call; when that factory call fails, the caller
// still sees the original DEAD_OBJECT code.
TEST(ResilientExecutionTest, computeDeadObjectFailedRecovery) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, compute(_)).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call()).WillOnce(kReturnGeneralFailure);

    const auto outcome = resilientExecution->compute({});

    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
139 
// DEAD_OBJECT from compute() causes the factory to supply a replacement execution. compute() is
// then issued once on that replacement, and its success is what the caller observes.
TEST(ResilientExecutionTest, computeDeadObjectSuccessfulRecovery) {
    const auto [mock, factory, resilientExecution] = setup();
    const auto replacement = createMockExecution();
    EXPECT_CALL(*mock, compute(_)).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call()).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, compute(_)).WillOnce(Return(kNoExecutionError));

    const auto outcome = resilientExecution->compute({});

    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
155 
// A successful computeFenced() on the underlying mock execution yields a successful result.
TEST(ResilientExecutionTest, computeFenced) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, computeFenced(_, _, _)).WillOnce(Return(kNoFencedExecutionError));

    const auto outcome = resilientExecution->computeFenced({}, {}, {});

    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
170 
// A GENERAL_FAILURE from computeFenced() reaches the caller with the same code. setup() allows
// exactly one factory call (consumed by create()), so a recovery attempt would fail the test.
TEST(ResilientExecutionTest, computeFencedError) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, computeFenced(_, _, _)).WillOnce(kReturnGeneralFailure);

    const auto outcome = resilientExecution->computeFenced({}, {}, {});

    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
183 
// DEAD_OBJECT from computeFenced() causes one factory call; when that factory call fails, the
// caller still sees the original DEAD_OBJECT code.
TEST(ResilientExecutionTest, computeFencedDeadObjectFailedRecovery) {
    const auto [mock, factory, resilientExecution] = setup();
    EXPECT_CALL(*mock, computeFenced(_, _, _)).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call()).WillOnce(kReturnGeneralFailure);

    const auto outcome = resilientExecution->computeFenced({}, {}, {});

    ASSERT_FALSE(outcome.has_value());
    EXPECT_EQ(outcome.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
197 
// DEAD_OBJECT from computeFenced() causes the factory to supply a replacement execution.
// computeFenced() is then issued once on that replacement, and its success is what the caller
// observes.
TEST(ResilientExecutionTest, computeFencedDeadObjectSuccessfulRecovery) {
    const auto [mock, factory, resilientExecution] = setup();
    const auto replacement = createMockExecution();
    EXPECT_CALL(*mock, computeFenced(_, _, _)).WillOnce(kReturnDeadObject);
    EXPECT_CALL(*factory, Call()).WillOnce(Return(replacement));
    EXPECT_CALL(*replacement, computeFenced(_, _, _)).WillOnce(Return(kNoFencedExecutionError));

    const auto outcome = resilientExecution->computeFenced({}, {}, {});

    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
}
215 
// Calling recover() with the current execution invokes the factory once and returns the new
// execution it produced.
TEST(ResilientExecutionTest, recover) {
    const auto [mock, factory, resilientExecution] = setup();
    const auto replacement = createMockExecution();
    EXPECT_CALL(*factory, Call()).WillOnce(Return(replacement));

    const auto recovered = resilientExecution->recover(mock.get());

    ASSERT_TRUE(recovered.has_value())
            << "Failed with " << recovered.error().code << ": " << recovered.error().message;
    EXPECT_TRUE(replacement == recovered.value());
}
230 
// If the factory fails while recover() is replacing the current execution, recover() reports an
// error instead of a new execution.
TEST(ResilientExecutionTest, recoverFailure) {
    // setup call
    // (No replacement mock is created here: the factory fails, so there is nothing to hand out.)
    const auto [mockExecution, mockExecutionFactory, execution] = setup();
    EXPECT_CALL(*mockExecutionFactory, Call()).Times(1).WillOnce(kReturnGeneralFailure);

    // run test
    const auto result = execution->recover(mockExecution.get());

    // verify result
    EXPECT_FALSE(result.has_value());
}
243 
// Once the execution has been recovered, a second recover() call that still names the original
// (now stale) execution returns the already-recovered execution. The factory expectation below
// allows only one call, so a second factory invocation would fail the test.
TEST(ResilientExecutionTest, someoneElseRecovered) {
    // setup call
    const auto [mockExecution, mockExecutionFactory, execution] = setup();
    const auto recoveredMockExecution = createMockExecution();
    EXPECT_CALL(*mockExecutionFactory, Call()).Times(1).WillOnce(Return(recoveredMockExecution));

    // Simulate "someone else" recovering first. Verify that it succeeded; otherwise the checks
    // below would no longer exercise the already-recovered scenario.
    const auto firstRecovery = execution->recover(mockExecution.get());
    ASSERT_TRUE(firstRecovery.has_value())
            << "First recovery failed with " << firstRecovery.error().code << ": "
            << firstRecovery.error().message;

    // run test
    const auto result = execution->recover(mockExecution.get());

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    EXPECT_TRUE(result.value() == recoveredMockExecution);
}
259 
260 }  // namespace android::hardware::neuralnetworks::utils
261