1 /*
2  * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "hal/snoop_logger_socket_thread.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include <netinet/in.h>
22 #include <sys/socket.h>
23 
24 #include <future>
25 
26 #include "common/init_flags.h"
27 #include "hal/snoop_logger_common.h"
28 #include "hal/syscall_wrapper_impl.h"
29 #include "os/utils.h"
30 
31 namespace testing {
32 
33 using bluetooth::hal::SnoopLoggerCommon;
34 using bluetooth::hal::SnoopLoggerSocket;
35 using bluetooth::hal::SnoopLoggerSocketThread;
36 using bluetooth::hal::SyscallWrapperImpl;
37 
// Value the tests compare the result of socket() against to detect a failed call.
static constexpr int INVALID_FD{-1};
39 
40 class SnoopLoggerSocketThreadModuleTest : public Test {};
41 
// Starts the socket thread and lets it go out of scope without an explicit
// Stop() call.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_start_no_stop_test) {
  {
    SyscallWrapperImpl syscalls;
    SnoopLoggerSocketThread socket_thread(std::make_unique<SnoopLoggerSocket>(&syscalls));

    // Block until the start result is available, then require success.
    auto started = socket_thread.Start();
    started.wait();
    ASSERT_TRUE(started.get());
  }

  // Destructor calls Stop();
}
53 
// Stop() on a thread that was never started must leave it reported as not
// running.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_stop_no_start_test) {
  SyscallWrapperImpl syscalls;
  SnoopLoggerSocketThread socket_thread(std::make_unique<SnoopLoggerSocket>(&syscalls));

  socket_thread.Stop();
  ASSERT_FALSE(socket_thread.ThreadIsRunning());
}
61 
// A single Start() followed by Stop() must succeed and leave the thread
// reported as not running.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_start_stop_test) {
  SyscallWrapperImpl syscalls;
  SnoopLoggerSocketThread socket_thread(std::make_unique<SnoopLoggerSocket>(&syscalls));

  // Block until the start result is available, then require success.
  auto started = socket_thread.Start();
  started.wait();
  ASSERT_TRUE(started.get());

  socket_thread.Stop();
  ASSERT_FALSE(socket_thread.ThreadIsRunning());
}
73 
// The same SnoopLoggerSocketThread instance must survive several consecutive
// Start()/Stop() cycles, ending each cycle in the not-running state.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_repeated_start_stop_test) {
  constexpr int kCycles = 10;

  SyscallWrapperImpl syscalls;
  SnoopLoggerSocketThread socket_thread(std::make_unique<SnoopLoggerSocket>(&syscalls));

  for (int cycle = 0; cycle != kCycles; ++cycle) {
    auto started = socket_thread.Start();
    started.wait();
    ASSERT_TRUE(started.get());

    socket_thread.Stop();
    ASSERT_FALSE(socket_thread.ThreadIsRunning());
  }
}
91 
// A TCP client must be able to connect to the running snoop logger socket, and
// the socket thread must still stop cleanly while that client is connected.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_connect_test) {
  int ret = 0;
  SyscallWrapperImpl socket_if;
  SnoopLoggerSocketThread sls(std::make_unique<SnoopLoggerSocket>(&socket_if));
  auto thread_start_future = sls.Start();
  thread_start_future.wait();
  ASSERT_TRUE(thread_start_future.get());

  // Create a TCP socket file descriptor
  int socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
  ASSERT_NE(socket_fd, INVALID_FD);

  // Value-initialize so that every field not assigned below (e.g. the sin_zero
  // padding) is zero instead of indeterminate when handed to connect().
  struct sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(SnoopLoggerSocket::DEFAULT_LOCALHOST_);
  addr.sin_port = htons(SnoopLoggerSocket::DEFAULT_LISTEN_PORT_);

  // Connect to snoop logger socket
  RUN_NO_INTR(ret = connect(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
  ASSERT_EQ(ret, 0);

  sls.Stop();

  ASSERT_FALSE(sls.ThreadIsRunning());
  close(socket_fd);
}
118 
// A TCP client connects to the running snoop logger socket and then closes its
// end; the socket thread must still stop cleanly afterwards.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_connect_disconnect_test) {
  int ret = 0;
  SyscallWrapperImpl socket_if;
  SnoopLoggerSocketThread sls(std::make_unique<SnoopLoggerSocket>(&socket_if));
  auto thread_start_future = sls.Start();
  thread_start_future.wait();
  ASSERT_TRUE(thread_start_future.get());

  // Create a TCP socket file descriptor
  int socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
  ASSERT_NE(socket_fd, INVALID_FD);

  // Value-initialize so that every field not assigned below (e.g. the sin_zero
  // padding) is zero instead of indeterminate when handed to connect().
  struct sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(SnoopLoggerSocket::DEFAULT_LOCALHOST_);
  addr.sin_port = htons(SnoopLoggerSocket::DEFAULT_LISTEN_PORT_);

  // Connect to snoop logger socket
  RUN_NO_INTR(ret = connect(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
  ASSERT_EQ(ret, 0);

  // Close the client end of the connection. This is the only close of
  // socket_fd in this test: closing it a second time would act on a stale
  // descriptor number that may already have been reused for something else.
  RUN_NO_INTR(ret = close(socket_fd));
  ASSERT_EQ(ret, 0);

  sls.Stop();

  ASSERT_FALSE(sls.ThreadIsRunning());
}
149 
// Write() on a thread that was never started must not cause it to be reported
// as running.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_send_no_start_test) {
  SyscallWrapperImpl syscalls;
  SnoopLoggerSocketThread socket_thread(std::make_unique<SnoopLoggerSocket>(&syscalls));
  ASSERT_FALSE(socket_thread.ThreadIsRunning());

  // Use the btsnoop file header as an arbitrary payload.
  const auto* payload = &SnoopLoggerCommon::kBtSnoopFileHeader;
  socket_thread.Write(payload, sizeof(SnoopLoggerCommon::FileHeaderType));

  ASSERT_FALSE(socket_thread.ThreadIsRunning());
}
160 
// Data written before any client has connected must not be delivered to a
// client that connects afterwards: the client receives the file header and
// then finds nothing further pending.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_send_before_connect_test) {
  int ret = 0;
  SyscallWrapperImpl socket_if;
  SnoopLoggerSocketThread sls(std::make_unique<SnoopLoggerSocket>(&socket_if));
  auto thread_start_future = sls.Start();
  thread_start_future.wait();
  ASSERT_TRUE(thread_start_future.get());

  char test_data[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0f};
  sls.Write(test_data, sizeof(test_data));

  // Create a TCP socket file descriptor
  int socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
  ASSERT_NE(socket_fd, INVALID_FD);

  // Value-initialize so that every field not assigned below (e.g. the sin_zero
  // padding) is zero instead of indeterminate when handed to connect().
  struct sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(SnoopLoggerSocket::DEFAULT_LOCALHOST_);
  addr.sin_port = htons(SnoopLoggerSocket::DEFAULT_LISTEN_PORT_);

  // Connect to snoop logger socket
  RUN_NO_INTR(ret = connect(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
  ASSERT_EQ(ret, 0);

  char recv_buf1[sizeof(SnoopLoggerCommon::FileHeaderType)];
  char recv_buf2[sizeof(test_data)];
  int bytes_read = -1;
  // Result of the first (blocking) recv. It is written by the async task and
  // only read after a.get(), which synchronizes with the task's completion.
  ssize_t header_bytes_read = -1;

  auto a = std::async(std::launch::async, [socket_fd, &recv_buf1, &recv_buf2, &header_bytes_read] {
    header_bytes_read = recv(socket_fd, recv_buf1, sizeof(recv_buf1), 0);
    // Non-blocking: returns -1 immediately when nothing else is pending.
    return recv(socket_fd, recv_buf2, sizeof(recv_buf2), MSG_DONTWAIT);
  });

  sls.GetSocket()->WaitForClientSocketConnected();
  a.wait();
  bytes_read = a.get();

  // Without this check the test would also pass if the connection delivered
  // nothing at all, because both recv() calls would then return -1.
  ASSERT_EQ(header_bytes_read, static_cast<ssize_t>(sizeof(recv_buf1)));
  ASSERT_EQ(bytes_read, -1);
  close(socket_fd);
}
200 
// A newly connected client must receive exactly the btsnoop file header
// (SnoopLoggerCommon::kBtSnoopFileHeader) as its first data.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_recv_file_header_test) {
  int ret = 0;
  SyscallWrapperImpl socket_if;
  SnoopLoggerSocketThread sls(std::make_unique<SnoopLoggerSocket>(&socket_if));
  auto thread_start_future = sls.Start();
  thread_start_future.wait();
  ASSERT_TRUE(thread_start_future.get());

  // Create a TCP socket file descriptor
  int socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
  ASSERT_NE(socket_fd, INVALID_FD);

  // Value-initialize so that every field not assigned below (e.g. the sin_zero
  // padding) is zero instead of indeterminate when handed to connect().
  struct sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(SnoopLoggerSocket::DEFAULT_LOCALHOST_);
  addr.sin_port = htons(SnoopLoggerSocket::DEFAULT_LISTEN_PORT_);

  // Connect to snoop logger socket
  RUN_NO_INTR(ret = connect(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
  ASSERT_EQ(ret, 0);

  char recv_buf[sizeof(SnoopLoggerCommon::FileHeaderType)];
  int bytes_read = -1;

  // Receive on a separate thread so this thread can wait for the logger side
  // to register the client connection.
  auto a = std::async(std::launch::async, [socket_fd, &recv_buf] {
    return recv(socket_fd, recv_buf, sizeof(recv_buf), 0);
  });

  sls.GetSocket()->WaitForClientSocketConnected();

  a.wait();
  bytes_read = a.get();

  // Check the length first so the memcmp below only reads bytes that recv()
  // actually filled in.
  ASSERT_EQ(bytes_read, static_cast<int>(sizeof(SnoopLoggerCommon::FileHeaderType)));
  ASSERT_EQ(std::memcmp(recv_buf, &SnoopLoggerCommon::kBtSnoopFileHeader, bytes_read), 0);
  close(socket_fd);
}
238 
// Data written after a client has connected must reach that client unchanged,
// following the btsnoop file header.
TEST_F(SnoopLoggerSocketThreadModuleTest, socket_send_recv_test) {
  int ret = 0;
  SyscallWrapperImpl socket_if;
  SnoopLoggerSocketThread sls(std::make_unique<SnoopLoggerSocket>(&socket_if));
  auto thread_start_future = sls.Start();
  thread_start_future.wait();
  ASSERT_TRUE(thread_start_future.get());

  // Create a TCP socket file descriptor
  int socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
  ASSERT_NE(socket_fd, INVALID_FD);

  // Value-initialize so that every field not assigned below (e.g. the sin_zero
  // padding) is zero instead of indeterminate when handed to connect().
  struct sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(SnoopLoggerSocket::DEFAULT_LOCALHOST_);
  addr.sin_port = htons(SnoopLoggerSocket::DEFAULT_LISTEN_PORT_);

  // Connect to snoop logger socket
  RUN_NO_INTR(ret = connect(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
  ASSERT_EQ(ret, 0);

  char test_data[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0f};

  char recv_buf1[sizeof(SnoopLoggerCommon::FileHeaderType)];
  char recv_buf2[sizeof(test_data)];
  int bytes_read = -1;
  // Result of the first recv. It is written by the async task and only read
  // after a.get(), which synchronizes with the task's completion.
  ssize_t header_bytes_read = -1;

  auto a = std::async(std::launch::async, [socket_fd, &recv_buf1, &recv_buf2, &header_bytes_read] {
    header_bytes_read = recv(socket_fd, recv_buf1, sizeof(recv_buf1), 0);
    return recv(socket_fd, recv_buf2, sizeof(recv_buf2), 0);
  });

  sls.GetSocket()->WaitForClientSocketConnected();

  sls.Write(test_data, sizeof(test_data));
  a.wait();
  bytes_read = a.get();

  // Confirm the whole header was received before comparing it; otherwise the
  // memcmp would read bytes of recv_buf1 that were never written.
  ASSERT_EQ(header_bytes_read, static_cast<ssize_t>(sizeof(recv_buf1)));
  ASSERT_EQ(std::memcmp(recv_buf1, &SnoopLoggerCommon::kBtSnoopFileHeader, sizeof(recv_buf1)), 0);

  ASSERT_EQ(bytes_read, static_cast<int>(sizeof(test_data)));
  ASSERT_EQ(std::memcmp(recv_buf2, test_data, bytes_read), 0);
  close(socket_fd);
}
283 
284 }  // namespace testing
285