1 /******************************************************************************
2  *
3  *  Copyright 2020 Google, Inc.
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include "os/internal/wakelock_native.h"
20 
21 #include <aidl/android/system/suspend/BnSuspendCallback.h>
22 #include <aidl/android/system/suspend/BnWakelockCallback.h>
23 #include <aidl/android/system/suspend/ISuspendControlService.h>
24 #include <android/binder_auto_utils.h>
25 #include <android/binder_interface_utils.h>
26 #include <android/binder_manager.h>
27 #include <android/binder_process.h>
28 #include <gtest/gtest.h>
29 
30 #include <chrono>
31 #include <future>
32 #include <memory>
33 #include <mutex>
34 
35 namespace testing {
36 
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44 
// Name of the wakelock that every test acquires/releases and that the wakelock
// callback is registered for.
static const std::string kTestWakelockName{"BtWakelockNativeTestLock"};

// Guards acquire_promise / release_promise and the callback counters. It is a
// recursive mutex because the binder callbacks call
// PromiseFutureContext::FulfilPromise() while already holding it.
static std::recursive_mutex mutex;
// One-shot slots armed by PromiseFutureContext and fulfilled by WakelockCallback
// when the service reports an acquire / release. nullptr means "not armed".
static std::unique_ptr<std::promise<void>> acquire_promise{};
static std::unique_ptr<std::promise<void>> release_promise{};
50 
51 class PromiseFutureContext {
52  public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53   static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55     if (promise != nullptr) {
56       promise->set_value();
57       promise = nullptr;
58     }
59   }
60 
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)61   explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise, bool expect_fulfillment)
62       : promise_(promise), expect_fulfillment_(expect_fulfillment) {
63     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
64     EXPECT_EQ(promise_, nullptr);
65     promise_ = std::make_unique<std::promise<void>>();
66     future_ = promise->get_future();
67   }
68 
~PromiseFutureContext()69   ~PromiseFutureContext() {
70     auto future_status = future_.wait_for(std::chrono::seconds(2));
71     if (expect_fulfillment_) {
72       EXPECT_EQ(future_status, std::future_status::ready);
73     } else {
74       EXPECT_NE(future_status, std::future_status::ready);
75     }
76     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
77     promise_ = nullptr;
78   }
79 
80  private:
81   std::unique_ptr<std::promise<void>>& promise_;
82   bool expect_fulfillment_ = true;
83   std::future<void> future_;
84 };
85 
86 class WakelockCallback : public BnWakelockCallback {
87  public:
notifyAcquired()88   ScopedAStatus notifyAcquired() override {
89     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
90     net_acquired_count++;
91     fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
92     PromiseFutureContext::FulfilPromise(acquire_promise);
93     return ScopedAStatus::ok();
94   }
notifyReleased()95   ScopedAStatus notifyReleased() override {
96     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
97     net_acquired_count--;
98     fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
99     PromiseFutureContext::FulfilPromise(release_promise);
100     return ScopedAStatus::ok();
101   }
102 
103   int net_acquired_count = 0;
104 };
105 
106 class SuspendCallback : public BnSuspendCallback {
107  public:
notifyWakeup(bool success,const std::vector<std::string> & wakeup_reasons)108   ScopedAStatus notifyWakeup(bool success, const std::vector<std::string>& wakeup_reasons) override {
109     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
110     fprintf(stderr, "notifyWakeup\n");
111     return ScopedAStatus::ok();
112   }
113 };
114 
115 // There is no way to unregister these callbacks besides when this process dies
116 // Hence, we want to have only one copy of these callbacks per process
117 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
118 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
119 
120 class WakelockNativeTest : public Test {
121  protected:
SetUp()122   void SetUp() override {
123     ABinderProcess_setThreadPoolMaxThreadCount(1);
124     ABinderProcess_startThreadPool();
125 
126     WakelockNative::Get().Initialize();
127 
128     auto binder_raw = AServiceManager_getService("suspend_control");
129     ASSERT_NE(binder_raw, nullptr);
130     binder.set(binder_raw);
131     control_service_ = ISuspendControlService::fromBinder(binder);
132     if (control_service_ == nullptr) {
133       FAIL() << "Fail to obtain suspend_control";
134     }
135 
136     if (suspend_callback == nullptr) {
137       suspend_callback = SharedRefBase::make<SuspendCallback>();
138       bool is_registered = false;
139       ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
140       if (!is_registered || !status.isOk()) {
141         FAIL() << "Fail to register suspend callback";
142       }
143     }
144 
145     if (control_callback == nullptr) {
146       control_callback = SharedRefBase::make<WakelockCallback>();
147       bool is_registered = false;
148       ScopedAStatus status =
149           control_service_->registerWakelockCallback(control_callback, kTestWakelockName, &is_registered);
150       if (!is_registered || !status.isOk()) {
151         FAIL() << "Fail to register wakeup callback";
152       }
153     }
154     control_callback->net_acquired_count = 0;
155   }
156 
TearDown()157   void TearDown() override {
158     control_service_ = nullptr;
159     binder.set(nullptr);
160     WakelockNative::Get().CleanUp();
161   }
162 
163   SpAIBinder binder;
164   std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
165 };
166 
TEST_F(WakelockNativeTest,test_acquire_and_release_wakelocks)167 TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
168   ASSERT_EQ(control_callback->net_acquired_count, 0);
169 
170   {
171     PromiseFutureContext context(acquire_promise, true);
172     auto status = WakelockNative::Get().Acquire(kTestWakelockName);
173     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
174   }
175   ASSERT_EQ(control_callback->net_acquired_count, 1);
176 
177   {
178     PromiseFutureContext context(release_promise, true);
179     auto status = WakelockNative::Get().Release(kTestWakelockName);
180     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
181   }
182   ASSERT_EQ(control_callback->net_acquired_count, 0);
183 }
184 
TEST_F(WakelockNativeTest,test_acquire_and_release_wakelocks_repeated_acquire)185 TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
186   ASSERT_EQ(control_callback->net_acquired_count, 0);
187 
188   {
189     PromiseFutureContext context(acquire_promise, true);
190     auto status = WakelockNative::Get().Acquire(kTestWakelockName);
191     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
192   }
193   ASSERT_EQ(control_callback->net_acquired_count, 1);
194 
195   {
196     PromiseFutureContext context(acquire_promise, false);
197     auto status = WakelockNative::Get().Acquire(kTestWakelockName);
198     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
199   }
200   ASSERT_EQ(control_callback->net_acquired_count, 1);
201 
202   {
203     PromiseFutureContext context(release_promise, true);
204     auto status = WakelockNative::Get().Release(kTestWakelockName);
205     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
206   }
207   ASSERT_EQ(control_callback->net_acquired_count, 0);
208 }
209 
TEST_F(WakelockNativeTest,test_acquire_and_release_wakelocks_repeated_release)210 TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
211   ASSERT_EQ(control_callback->net_acquired_count, 0);
212 
213   {
214     PromiseFutureContext context(acquire_promise, true);
215     auto status = WakelockNative::Get().Acquire(kTestWakelockName);
216     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
217   }
218   ASSERT_EQ(control_callback->net_acquired_count, 1);
219 
220   {
221     PromiseFutureContext context(release_promise, true);
222     auto status = WakelockNative::Get().Release(kTestWakelockName);
223     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
224   }
225   ASSERT_EQ(control_callback->net_acquired_count, 0);
226 
227   {
228     PromiseFutureContext context(release_promise, false);
229     auto status = WakelockNative::Get().Release(kTestWakelockName);
230     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
231   }
232   ASSERT_EQ(control_callback->net_acquired_count, 0);
233 }
234 
TEST_F(WakelockNativeTest,test_acquire_and_release_wakelocks_in_a_loop)235 TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
236   ASSERT_EQ(control_callback->net_acquired_count, 0);
237 
238   for (int i = 0; i < 10; ++i) {
239     {
240       PromiseFutureContext context(acquire_promise, true);
241       auto status = WakelockNative::Get().Acquire(kTestWakelockName);
242       ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
243     }
244     ASSERT_EQ(control_callback->net_acquired_count, 1);
245 
246     {
247       PromiseFutureContext context(release_promise, true);
248       auto status = WakelockNative::Get().Release(kTestWakelockName);
249       ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
250     }
251     ASSERT_EQ(control_callback->net_acquired_count, 0);
252   }
253 }
254 
TEST_F(WakelockNativeTest,test_clean_up)255 TEST_F(WakelockNativeTest, test_clean_up) {
256   WakelockNative::Get().Initialize();
257   ASSERT_EQ(control_callback->net_acquired_count, 0);
258 
259   {
260     PromiseFutureContext context(acquire_promise, true);
261     auto status = WakelockNative::Get().Acquire(kTestWakelockName);
262     ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
263   }
264   ASSERT_EQ(control_callback->net_acquired_count, 1);
265 
266   {
267     PromiseFutureContext context(release_promise, true);
268     WakelockNative::Get().CleanUp();
269   }
270   ASSERT_EQ(control_callback->net_acquired_count, 0);
271 }
272 
273 }  // namespace testing