1 /******************************************************************************
2 *
3 * Copyright 2020 Google, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at:
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 ******************************************************************************/
18
19 #include "os/internal/wakelock_native.h"
20
21 #include <aidl/android/system/suspend/BnSuspendCallback.h>
22 #include <aidl/android/system/suspend/BnWakelockCallback.h>
23 #include <aidl/android/system/suspend/ISuspendControlService.h>
24 #include <android/binder_auto_utils.h>
25 #include <android/binder_interface_utils.h>
26 #include <android/binder_manager.h>
27 #include <android/binder_process.h>
28 #include <gtest/gtest.h>
29
#include <atomic>
#include <chrono>
#include <future>
#include <memory>
#include <mutex>
34
35 namespace testing {
36
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44
// Name of the wakelock every test acquires and releases through WakelockNative.
static const std::string kTestWakelockName{"BtWakelockNativeTestLock"};

// Guards the two promise slots below; every binder callback holds it while running.
// Recursive because the callbacks keep it locked while calling
// PromiseFutureContext::FulfilPromise, which locks it again.
static std::recursive_mutex mutex;

// Promise slots armed by PromiseFutureContext and fulfilled by notifyAcquired /
// notifyReleased respectively. An empty (default-constructed) pointer means "not armed".
static std::unique_ptr<std::promise<void>> acquire_promise;
static std::unique_ptr<std::promise<void>> release_promise;
50
51 class PromiseFutureContext {
52 public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53 static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55 if (promise != nullptr) {
56 promise->set_value();
57 promise = nullptr;
58 }
59 }
60
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)61 explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise, bool expect_fulfillment)
62 : promise_(promise), expect_fulfillment_(expect_fulfillment) {
63 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
64 EXPECT_EQ(promise_, nullptr);
65 promise_ = std::make_unique<std::promise<void>>();
66 future_ = promise->get_future();
67 }
68
~PromiseFutureContext()69 ~PromiseFutureContext() {
70 auto future_status = future_.wait_for(std::chrono::seconds(2));
71 if (expect_fulfillment_) {
72 EXPECT_EQ(future_status, std::future_status::ready);
73 } else {
74 EXPECT_NE(future_status, std::future_status::ready);
75 }
76 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
77 promise_ = nullptr;
78 }
79
80 private:
81 std::unique_ptr<std::promise<void>>& promise_;
82 bool expect_fulfillment_ = true;
83 std::future<void> future_;
84 };
85
86 class WakelockCallback : public BnWakelockCallback {
87 public:
notifyAcquired()88 ScopedAStatus notifyAcquired() override {
89 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
90 net_acquired_count++;
91 fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
92 PromiseFutureContext::FulfilPromise(acquire_promise);
93 return ScopedAStatus::ok();
94 }
notifyReleased()95 ScopedAStatus notifyReleased() override {
96 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
97 net_acquired_count--;
98 fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
99 PromiseFutureContext::FulfilPromise(release_promise);
100 return ScopedAStatus::ok();
101 }
102
103 int net_acquired_count = 0;
104 };
105
106 class SuspendCallback : public BnSuspendCallback {
107 public:
notifyWakeup(bool success,const std::vector<std::string> & wakeup_reasons)108 ScopedAStatus notifyWakeup(bool success, const std::vector<std::string>& wakeup_reasons) override {
109 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
110 fprintf(stderr, "notifyWakeup\n");
111 return ScopedAStatus::ok();
112 }
113 };
114
115 // There is no way to unregister these callbacks besides when this process dies
116 // Hence, we want to have only one copy of these callbacks per process
117 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
118 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
119
120 class WakelockNativeTest : public Test {
121 protected:
SetUp()122 void SetUp() override {
123 ABinderProcess_setThreadPoolMaxThreadCount(1);
124 ABinderProcess_startThreadPool();
125
126 WakelockNative::Get().Initialize();
127
128 auto binder_raw = AServiceManager_getService("suspend_control");
129 ASSERT_NE(binder_raw, nullptr);
130 binder.set(binder_raw);
131 control_service_ = ISuspendControlService::fromBinder(binder);
132 if (control_service_ == nullptr) {
133 FAIL() << "Fail to obtain suspend_control";
134 }
135
136 if (suspend_callback == nullptr) {
137 suspend_callback = SharedRefBase::make<SuspendCallback>();
138 bool is_registered = false;
139 ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
140 if (!is_registered || !status.isOk()) {
141 FAIL() << "Fail to register suspend callback";
142 }
143 }
144
145 if (control_callback == nullptr) {
146 control_callback = SharedRefBase::make<WakelockCallback>();
147 bool is_registered = false;
148 ScopedAStatus status =
149 control_service_->registerWakelockCallback(control_callback, kTestWakelockName, &is_registered);
150 if (!is_registered || !status.isOk()) {
151 FAIL() << "Fail to register wakeup callback";
152 }
153 }
154 control_callback->net_acquired_count = 0;
155 }
156
TearDown()157 void TearDown() override {
158 control_service_ = nullptr;
159 binder.set(nullptr);
160 WakelockNative::Get().CleanUp();
161 }
162
163 SpAIBinder binder;
164 std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
165 };
166
// One acquire followed by one release: each must return SUCCESS and fire its
// callback. The net acquired count goes 0 -> 1 -> 0.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // Leaving this scope waits for notifyAcquired.
    PromiseFutureContext acquired(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // Leaving this scope waits for notifyReleased.
    PromiseFutureContext released(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
184
// Acquiring a wakelock that is already held must still return SUCCESS without a
// second notifyAcquired. A single release then brings the net count back to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // First acquire: the callback is expected.
    PromiseFutureContext first_acquire(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // Second acquire: the callback must NOT fire.
    PromiseFutureContext second_acquire(acquire_promise, false);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext released(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
209
// Releasing a wakelock that is no longer held must still return SUCCESS and must not
// produce another notifyReleased.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquired(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // First release: the callback is expected.
    PromiseFutureContext first_release(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // Second release: the callback must NOT fire.
    PromiseFutureContext second_release(release_promise, false);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
234
// Repeated acquire/release cycles: each cycle must fire exactly the expected callback
// and leave the net acquired count at 1 after acquire and 0 after release.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
  constexpr int kIterations = 10;
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  for (int i = 0; i < kIterations; ++i) {
    // Tag any failure below with the iteration that produced it.
    SCOPED_TRACE(::testing::Message() << "iteration " << i);

    {
      PromiseFutureContext context(acquire_promise, true);
      auto status = WakelockNative::Get().Acquire(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 1);

    {
      PromiseFutureContext context(release_promise, true);
      auto status = WakelockNative::Get().Release(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 0);
  }
}
254
// WakelockNative::CleanUp() must release a wakelock that is still held, which shows
// up as a notifyReleased callback and a net count back at zero.
TEST_F(WakelockNativeTest, test_clean_up) {
  // NOTE(review): SetUp() already called Initialize(); this second call looks
  // redundant unless it is meant to exercise re-initialization — confirm.
  WakelockNative::Get().Initialize();
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquired(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName), WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // No explicit Release(): CleanUp() itself is expected to trigger notifyReleased.
    PromiseFutureContext released(release_promise, true);
    WakelockNative::Get().CleanUp();
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
272
273 } // namespace testing