1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "async_fd_watcher_unittest"
18 
#include "async_fd_watcher.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <log/log.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <vector>
33 
34 namespace android::hardware::bluetooth::async_test {
35 
36 using android::hardware::bluetooth::async::AsyncFdWatcher;
37 
38 class AsyncFdWatcherSocketTest : public ::testing::Test {
39  public:
40   static const uint16_t kPort = 6111;
41   static const size_t kBufferSize = 16;
42 
CheckBufferEquals()43   bool CheckBufferEquals() {
44     return strcmp(server_buffer_, client_buffer_) == 0;
45   }
46 
47  protected:
StartServer()48   int StartServer() {
49     ALOGD("%s", __func__);
50     struct sockaddr_in serv_addr;
51     int fd = socket(AF_INET, SOCK_STREAM, 0);
52     EXPECT_FALSE(fd < 0);
53 
54     memset(&serv_addr, 0, sizeof(serv_addr));
55     serv_addr.sin_family = AF_INET;
56     serv_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
57     serv_addr.sin_port = htons(kPort);
58     int reuse_flag = 1;
59     EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
60                             sizeof(reuse_flag)) < 0);
61     EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
62 
63     ALOGD("%s before listen", __func__);
64     listen(fd, 1);
65     return fd;
66   }
67 
AcceptConnection(int fd)68   int AcceptConnection(int fd) {
69     ALOGD("%s", __func__);
70     struct sockaddr_in cli_addr;
71     memset(&cli_addr, 0, sizeof(cli_addr));
72     socklen_t clilen = sizeof(cli_addr);
73 
74     int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
75     EXPECT_FALSE(connection_fd < 0);
76 
77     return connection_fd;
78   }
79 
ReadIncomingMessage(int fd)80   void ReadIncomingMessage(int fd) {
81     ALOGD("%s", __func__);
82     int n = TEMP_FAILURE_RETRY(read(fd, server_buffer_, kBufferSize - 1));
83     EXPECT_FALSE(n < 0);
84 
85     if (n == 0) {  // got EOF
86       ALOGD("%s: EOF", __func__);
87     } else {
88       ALOGD("%s: Got something", __func__);
89       n = write(fd, "1", 1);
90     }
91   }
92 
SetUp()93   void SetUp() override {
94     ALOGD("%s", __func__);
95     memset(server_buffer_, 0, kBufferSize);
96     memset(client_buffer_, 0, kBufferSize);
97   }
98 
ConfigureServer()99   void ConfigureServer() {
100     socket_fd_ = StartServer();
101 
102     conn_watcher_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
103       int connection_fd = AcceptConnection(fd);
104       ALOGD("%s: Conn_watcher fd = %d", __func__, fd);
105 
106       conn_watcher_.ConfigureTimeout(std::chrono::seconds(0), []() {
107         bool connection_timeout_cleared = false;
108         ASSERT_TRUE(connection_timeout_cleared);
109       });
110 
111       ALOGD("%s: 3", __func__);
112       async_fd_watcher_.WatchFdForNonBlockingReads(
113           connection_fd, [this](int fd) { ReadIncomingMessage(fd); });
114 
115       // Time out if it takes longer than a second.
116       SetTimeout(std::chrono::seconds(1));
117     });
118     conn_watcher_.ConfigureTimeout(std::chrono::seconds(1), []() {
119       bool connection_timeout = true;
120       ASSERT_FALSE(connection_timeout);
121     });
122   }
123 
CleanUpServer()124   void CleanUpServer() {
125     async_fd_watcher_.StopWatchingFileDescriptors();
126     conn_watcher_.StopWatchingFileDescriptors();
127     close(socket_fd_);
128   }
129 
TearDown()130   void TearDown() override {
131     ALOGD("%s 3", __func__);
132     EXPECT_TRUE(CheckBufferEquals());
133   }
134 
OnTimeout()135   void OnTimeout() {
136     ALOGD("%s", __func__);
137     timed_out_ = true;
138   }
139 
ClearTimeout()140   void ClearTimeout() {
141     ALOGD("%s", __func__);
142     timed_out_ = false;
143   }
144 
TimedOut()145   bool TimedOut() {
146     ALOGD("%s %d", __func__, timed_out_ ? 1 : 0);
147     return timed_out_;
148   }
149 
SetTimeout(std::chrono::milliseconds timeout_ms)150   void SetTimeout(std::chrono::milliseconds timeout_ms) {
151     ALOGD("%s", __func__);
152     async_fd_watcher_.ConfigureTimeout(timeout_ms, [this]() { OnTimeout(); });
153     ClearTimeout();
154   }
155 
ConnectClient()156   int ConnectClient() {
157     ALOGD("%s", __func__);
158     int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
159     EXPECT_FALSE(socket_cli_fd < 0);
160 
161     struct sockaddr_in serv_addr;
162     memset((void*)&serv_addr, 0, sizeof(serv_addr));
163     serv_addr.sin_family = AF_INET;
164     serv_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
165     serv_addr.sin_port = htons(kPort);
166 
167     int result =
168         connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
169     EXPECT_FALSE(result < 0);
170 
171     return socket_cli_fd;
172   }
173 
WriteFromClient(int socket_cli_fd)174   void WriteFromClient(int socket_cli_fd) {
175     ALOGD("%s", __func__);
176     strcpy(client_buffer_, "1");
177     int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
178     EXPECT_TRUE(n > 0);
179   }
180 
AwaitServerResponse(int socket_cli_fd)181   void AwaitServerResponse(int socket_cli_fd) {
182     ALOGD("%s", __func__);
183     int n = read(socket_cli_fd, client_buffer_, 1);
184     ALOGD("%s done", __func__);
185     EXPECT_TRUE(n > 0);
186   }
187 
188  private:
189   AsyncFdWatcher async_fd_watcher_;
190   AsyncFdWatcher conn_watcher_;
191   int socket_fd_;
192   char server_buffer_[kBufferSize];
193   char client_buffer_[kBufferSize];
194   bool timed_out_;
195 };
196 
197 // Use a single AsyncFdWatcher to signal a connection to the server socket.
TEST_F(AsyncFdWatcherSocketTest,Connect)198 TEST_F(AsyncFdWatcherSocketTest, Connect) {
199   int socket_fd = StartServer();
200 
201   AsyncFdWatcher conn_watcher;
202   conn_watcher.WatchFdForNonBlockingReads(socket_fd, [this](int fd) {
203     int connection_fd = AcceptConnection(fd);
204     close(connection_fd);
205   });
206 
207   // Fail if the client doesn't connect within 1 second.
208   conn_watcher.ConfigureTimeout(std::chrono::seconds(1), []() {
209     bool connection_timeout = true;
210     ASSERT_FALSE(connection_timeout);
211   });
212 
213   int socket_cli_fd = ConnectClient();
214   conn_watcher.StopWatchingFileDescriptors();
215   close(socket_fd);
216   close(socket_cli_fd);
217 }
218 
219 // Use a single AsyncFdWatcher to signal a connection to the server socket.
TEST_F(AsyncFdWatcherSocketTest,TimedOutConnect)220 TEST_F(AsyncFdWatcherSocketTest, TimedOutConnect) {
221   int socket_fd = StartServer();
222   bool timed_out = false;
223   bool* timeout_ptr = &timed_out;
224 
225   AsyncFdWatcher conn_watcher;
226   conn_watcher.WatchFdForNonBlockingReads(socket_fd, [this](int fd) {
227     int connection_fd = AcceptConnection(fd);
228     close(connection_fd);
229   });
230 
231   // Set the timeout flag after 100ms.
232   conn_watcher.ConfigureTimeout(std::chrono::milliseconds(100),
233                                 [timeout_ptr]() { *timeout_ptr = true; });
234   EXPECT_FALSE(timed_out);
235   sleep(1);
236   EXPECT_TRUE(timed_out);
237   conn_watcher.StopWatchingFileDescriptors();
238   close(socket_fd);
239 }
240 
241 // Modify the timeout in a timeout callback.
TEST_F(AsyncFdWatcherSocketTest,TimedOutSchedulesTimeout)242 TEST_F(AsyncFdWatcherSocketTest, TimedOutSchedulesTimeout) {
243   int socket_fd = StartServer();
244   bool timed_out = false;
245   bool timed_out2 = false;
246 
247   AsyncFdWatcher conn_watcher;
248   conn_watcher.WatchFdForNonBlockingReads(socket_fd, [this](int fd) {
249     int connection_fd = AcceptConnection(fd);
250     close(connection_fd);
251   });
252 
253   // Set a timeout flag in each callback.
254   conn_watcher.ConfigureTimeout(std::chrono::milliseconds(500),
255                                 [&conn_watcher, &timed_out, &timed_out2]() {
256                                   timed_out = true;
257                                   conn_watcher.ConfigureTimeout(
258                                       std::chrono::seconds(1),
259                                       [&timed_out2]() { timed_out2 = true; });
260                                 });
261   EXPECT_FALSE(timed_out);
262   EXPECT_FALSE(timed_out2);
263   sleep(1);
264   EXPECT_TRUE(timed_out);
265   EXPECT_FALSE(timed_out2);
266   sleep(1);
267   EXPECT_TRUE(timed_out);
268   EXPECT_TRUE(timed_out2);
269   conn_watcher.StopWatchingFileDescriptors();
270   close(socket_fd);
271 }
272 
273 MATCHER_P(ReadAndMatchSingleChar, byte,
274           "Reads a byte from the file descriptor and matches the value against "
275           "byte") {
276   char inbuf[1] = {0};
277 
278   int n = TEMP_FAILURE_RETRY(read(arg, inbuf, 1));
279 
280   TEMP_FAILURE_RETRY(write(arg, inbuf, 1));
281   if (n != 1) {
282     return false;
283   }
284   return inbuf[0] == byte;
285 };
286 
287 // Use a single AsyncFdWatcher to watch two file descriptors.
TEST_F(AsyncFdWatcherSocketTest,WatchTwoFileDescriptors)288 TEST_F(AsyncFdWatcherSocketTest, WatchTwoFileDescriptors) {
289   int sockfd1[2];
290   int sockfd2[2];
291   socketpair(AF_LOCAL, SOCK_STREAM, 0, sockfd1);
292   socketpair(AF_LOCAL, SOCK_STREAM, 0, sockfd2);
293 
294   testing::MockFunction<void(int)> cb1;
295   testing::MockFunction<void(int)> cb2;
296 
297   AsyncFdWatcher watcher;
298   watcher.WatchFdForNonBlockingReads(sockfd1[0], cb1.AsStdFunction());
299 
300   watcher.WatchFdForNonBlockingReads(sockfd2[0], cb2.AsStdFunction());
301 
302   EXPECT_CALL(cb1, Call(ReadAndMatchSingleChar('1')));
303   char one_buf[1] = {'1'};
304   TEMP_FAILURE_RETRY(write(sockfd1[1], one_buf, sizeof(one_buf)));
305 
306   EXPECT_CALL(cb2, Call(ReadAndMatchSingleChar('2')));
307   char two_buf[1] = {'2'};
308   TEMP_FAILURE_RETRY(write(sockfd2[1], two_buf, sizeof(two_buf)));
309 
310   // Blocking read instead of a flush.
311   TEMP_FAILURE_RETRY(read(sockfd1[1], one_buf, sizeof(one_buf)));
312   TEMP_FAILURE_RETRY(read(sockfd2[1], two_buf, sizeof(two_buf)));
313 
314   watcher.StopWatchingFileDescriptors();
315 }
316 
317 // Use two AsyncFdWatchers to set up a server socket.
TEST_F(AsyncFdWatcherSocketTest,ClientServer)318 TEST_F(AsyncFdWatcherSocketTest, ClientServer) {
319   ConfigureServer();
320   int socket_cli_fd = ConnectClient();
321 
322   WriteFromClient(socket_cli_fd);
323 
324   AwaitServerResponse(socket_cli_fd);
325 
326   close(socket_cli_fd);
327   CleanUpServer();
328 }
329 
330 // Use two AsyncFdWatchers to set up a server socket, which times out.
TEST_F(AsyncFdWatcherSocketTest,TimeOutTest)331 TEST_F(AsyncFdWatcherSocketTest, TimeOutTest) {
332   ConfigureServer();
333   int socket_cli_fd = ConnectClient();
334 
335   while (!TimedOut()) sleep(1);
336 
337   close(socket_cli_fd);
338   CleanUpServer();
339 }
340 
341 // Use two AsyncFdWatchers to set up a server socket, which times out.
TEST_F(AsyncFdWatcherSocketTest,RepeatedTimeOutTest)342 TEST_F(AsyncFdWatcherSocketTest, RepeatedTimeOutTest) {
343   ConfigureServer();
344   int socket_cli_fd = ConnectClient();
345   ClearTimeout();
346 
347   // Time out when there are no writes.
348   EXPECT_FALSE(TimedOut());
349   sleep(2);
350   EXPECT_TRUE(TimedOut());
351   ClearTimeout();
352 
353   // Don't time out when there is a write.
354   WriteFromClient(socket_cli_fd);
355   AwaitServerResponse(socket_cli_fd);
356   EXPECT_FALSE(TimedOut());
357   ClearTimeout();
358 
359   // Time out when the write is late.
360   sleep(2);
361   WriteFromClient(socket_cli_fd);
362   AwaitServerResponse(socket_cli_fd);
363   EXPECT_TRUE(TimedOut());
364   ClearTimeout();
365 
366   // Time out when there is a pause after a write.
367   WriteFromClient(socket_cli_fd);
368   sleep(2);
369   AwaitServerResponse(socket_cli_fd);
370   EXPECT_TRUE(TimedOut());
371   ClearTimeout();
372 
373   close(socket_cli_fd);
374   CleanUpServer();
375 }
376 
377 }  // namespace android::hardware::bluetooth::async_test
378