1 /*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * binder_test.cpp - unit tests for netd binder RPCs.
17 */
18
#include <algorithm>
#include <cerrno>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <android-base/stringprintf.h>
#include <android-base/strings.h>
#include <cutils/multiuser.h>
#include <gtest/gtest.h>
#include <logwrap/logwrap.h>

#include "NetdConstants.h"
#include "android/net/INetd.h"
#include "android/net/UidRange.h"
#include "binder/IServiceManager.h"

using namespace android;
using namespace android::base;
using namespace android::binder;
using android::net::INetd;
using android::net::UidRange;
46
// Address-family selectors used as the ipVersion argument when listing `ip` rules:
// "-4" selects the IPv4 rule table, "-6" the IPv6 one.
// Both the pointer and the pointee are const so these cannot be reassigned by accident.
static const char* const IP_RULE_V4 = "-4";
static const char* const IP_RULE_V6 = "-6";
49
50 class BinderTest : public ::testing::Test {
51
52 public:
BinderTest()53 BinderTest() {
54 sp<IServiceManager> sm = defaultServiceManager();
55 sp<IBinder> binder = sm->getService(String16("netd"));
56 if (binder != nullptr) {
57 mNetd = interface_cast<INetd>(binder);
58 }
59 }
60
SetUp()61 void SetUp() {
62 ASSERT_NE(nullptr, mNetd.get());
63 }
64
65 protected:
66 sp<INetd> mNetd;
67 };
68
69
70 class TimedOperation : public Stopwatch {
71 public:
TimedOperation(std::string name)72 TimedOperation(std::string name): mName(name) {}
~TimedOperation()73 virtual ~TimedOperation() {
74 fprintf(stderr, " %s: %6.1f ms\n", mName.c_str(), timeTaken());
75 }
76
77 private:
78 std::string mName;
79 };
80
// Checks that netd answers the isAlive() RPC with true, and reports how long the call took.
TEST_F(BinderTest, TestIsAlive) {
    TimedOperation t("isAlive RPC");
    bool isAlive = false;
    // Check the binder status as well, so a transport failure is reported as such rather
    // than as "isAlive == false".
    ASSERT_TRUE(mNetd->isAlive(&isAlive).isOk());
    ASSERT_TRUE(isAlive);
}
87
randomUid()88 static int randomUid() {
89 return 100000 * arc4random_uniform(7) + 10000 + arc4random_uniform(5000);
90 }
91
// Runs `command` through popen() and returns everything it printed on stdout, one element
// per line. Each element keeps its trailing newline (if any) exactly as getline() read it.
// If popen() fails, the error is printed with perror() and an empty vector is returned.
static std::vector<std::string> runCommand(const std::string& command) {
    std::vector<std::string> lines;

    FILE* f = popen(command.c_str(), "r");
    if (f == nullptr) {
        perror("popen");
        return lines;
    }

    // One buffer is reused for every line; getline() grows it as needed. It is freed once
    // after the loop, because getline() may allocate it even on the call that hits EOF.
    char* line = nullptr;
    size_t bufsize = 0;
    ssize_t linelen;
    while ((linelen = getline(&line, &bufsize, f)) >= 0) {
        // Construct with an explicit length so embedded NUL bytes are preserved.
        lines.emplace_back(line, static_cast<size_t>(linelen));
    }
    free(line);

    pclose(f);
    return lines;
}
113
listIpRules(const char * ipVersion)114 static std::vector<std::string> listIpRules(const char *ipVersion) {
115 std::string command = StringPrintf("%s %s rule list", IP_PATH, ipVersion);
116 return runCommand(command);
117 }
118
listIptablesRule(const char * binary,const char * chainName)119 static std::vector<std::string> listIptablesRule(const char *binary, const char *chainName) {
120 std::string command = StringPrintf("%s -n -L %s", binary, chainName);
121 return runCommand(command);
122 }
123
iptablesRuleLineLength(const char * binary,const char * chainName)124 static int iptablesRuleLineLength(const char *binary, const char *chainName) {
125 return listIptablesRule(binary, chainName).size();
126 }
127
// Programs a 500-UID whitelist chain and a 500-UID blacklist chain with
// firewallReplaceUidChain(), clearing the chain after each. At every step it checks how many
// lines `iptables -n -L` / `ip6tables -n -L` print for the chain. Finally it checks that the
// RPC reports failure when iptables rejects the chain name.
TEST_F(BinderTest, TestFirewallReplaceUidChain) {
    const std::string chainName = StringPrintf("netd_binder_test_%u", arc4random_uniform(10000));
    const char* chain = chainName.c_str();
    const int kNumUids = 500;
    const std::vector<int32_t> noUids;
    std::vector<int32_t> uids(kNumUids);
    for (int32_t& uid : uids) {
        uid = randomUid();
    }

    // Replaces the chain contents while timing the RPC, checks that the binder transaction
    // succeeded, and returns netd's result. ret starts out false so a failed transaction is
    // never read back as an uninitialized value.
    auto replaceChain = [&](const std::string& label, bool isWhitelist,
                            const std::vector<int32_t>& chainUids) {
        bool ret = false;
        binder::Status status;
        {
            TimedOperation op(label);
            status = mNetd->firewallReplaceUidChain(String16(chain), isWhitelist, chainUids,
                                                    &ret);
        }
        EXPECT_TRUE(status.isOk()) << label;
        return ret;
    };

    EXPECT_TRUE(replaceChain(StringPrintf("Programming %d-UID whitelist chain", kNumUids),
                             true, uids));
    EXPECT_EQ(kNumUids + 5, iptablesRuleLineLength(IPTABLES_PATH, chain));
    EXPECT_EQ(kNumUids + 11, iptablesRuleLineLength(IP6TABLES_PATH, chain));

    EXPECT_TRUE(replaceChain("Clearing whitelist chain", false, noUids));
    EXPECT_EQ(3, iptablesRuleLineLength(IPTABLES_PATH, chain));
    EXPECT_EQ(3, iptablesRuleLineLength(IP6TABLES_PATH, chain));

    EXPECT_TRUE(replaceChain(StringPrintf("Programming %d-UID blacklist chain", kNumUids),
                             false, uids));
    EXPECT_EQ(kNumUids + 3, iptablesRuleLineLength(IPTABLES_PATH, chain));
    EXPECT_EQ(kNumUids + 3, iptablesRuleLineLength(IP6TABLES_PATH, chain));

    EXPECT_TRUE(replaceChain("Clearing blacklist chain", false, noUids));
    EXPECT_EQ(3, iptablesRuleLineLength(IPTABLES_PATH, chain));
    EXPECT_EQ(3, iptablesRuleLineLength(IP6TABLES_PATH, chain));

    // Check that the call fails if iptables returns an error. ret starts out true, so the
    // expectation only passes if netd actually wrote back a failure.
    const std::string veryLongStringName = "netd_binder_test_UnacceptablyLongIptablesChainName";
    bool ret = true;
    mNetd->firewallReplaceUidChain(String16(veryLongStringName.c_str()), true, noUids, &ret);
    EXPECT_FALSE(ret);
}
174
bandwidthDataSaverEnabled(const char * binary)175 static int bandwidthDataSaverEnabled(const char *binary) {
176 std::vector<std::string> lines = listIptablesRule(binary, "bw_data_saver");
177
178 // Output looks like this:
179 //
180 // Chain bw_data_saver (1 references)
181 // target prot opt source destination
182 // RETURN all -- 0.0.0.0/0 0.0.0.0/0
183 EXPECT_EQ(3U, lines.size());
184 if (lines.size() != 3) return -1;
185
186 EXPECT_TRUE(android::base::StartsWith(lines[2], "RETURN ") ||
187 android::base::StartsWith(lines[2], "REJECT "));
188
189 return android::base::StartsWith(lines[2], "REJECT");
190 }
191
enableDataSaver(sp<INetd> & netd,bool enable)192 bool enableDataSaver(sp<INetd>& netd, bool enable) {
193 TimedOperation op(enable ? " Enabling data saver" : "Disabling data saver");
194 bool ret;
195 netd->bandwidthEnableDataSaver(enable, &ret);
196 return ret;
197 }
198
getDataSaverState()199 int getDataSaverState() {
200 const int enabled4 = bandwidthDataSaverEnabled(IPTABLES_PATH);
201 const int enabled6 = bandwidthDataSaverEnabled(IP6TABLES_PATH);
202 EXPECT_EQ(enabled4, enabled6);
203 EXPECT_NE(-1, enabled4);
204 EXPECT_NE(-1, enabled6);
205 if (enabled4 != enabled6 || (enabled6 != 0 && enabled6 != 1)) {
206 return -1;
207 }
208 return enabled6;
209 }
210
// Walks data saver through off, off, on, on. After each request it checks that netd accepted
// it and that iptables/ip6tables agree on the new state. It then restores the state the
// device started in.
TEST_F(BinderTest, TestBandwidthEnableDataSaver) {
    const int wasEnabled = getDataSaverState();
    ASSERT_NE(-1, wasEnabled);

    // Start from "off" so the sequence below begins from a known state.
    if (wasEnabled) {
        ASSERT_TRUE(enableDataSaver(mNetd, false));
        EXPECT_EQ(0, getDataSaverState());
    }

    // Each state is requested twice to check that repeating a request is accepted.
    for (const bool enable : {false, false, true, true}) {
        ASSERT_TRUE(enableDataSaver(mNetd, enable));
        EXPECT_EQ(enable ? 1 : 0, getDataSaverState());
    }

    // The sequence leaves data saver on; turn it back off if that is how we found it.
    if (!wasEnabled) {
        ASSERT_TRUE(enableDataSaver(mNetd, false));
        EXPECT_EQ(0, getDataSaverState());
    }
}
234
ipRuleExistsForRange(const uint32_t priority,const UidRange & range,const std::string & action,const char * ipVersion)235 static bool ipRuleExistsForRange(const uint32_t priority, const UidRange& range,
236 const std::string& action, const char* ipVersion) {
237 // Output looks like this:
238 // "12500:\tfrom all fwmark 0x0/0x20000 iif lo uidrange 1000-2000 prohibit"
239 std::vector<std::string> rules = listIpRules(ipVersion);
240
241 std::string prefix = StringPrintf("%" PRIu32 ":", priority);
242 std::string suffix = StringPrintf(" iif lo uidrange %d-%d %s\n",
243 range.getStart(), range.getStop(), action.c_str());
244 for (std::string line : rules) {
245 if (android::base::StartsWith(line, prefix.c_str())
246 && android::base::EndsWith(line, suffix.c_str())) {
247 return true;
248 }
249 }
250 return false;
251 }
252
ipRuleExistsForRange(const uint32_t priority,const UidRange & range,const std::string & action)253 static bool ipRuleExistsForRange(const uint32_t priority, const UidRange& range,
254 const std::string& action) {
255 bool existsIp4 = ipRuleExistsForRange(priority, range, action, IP_RULE_V4);
256 bool existsIp6 = ipRuleExistsForRange(priority, range, action, IP_RULE_V6);
257 EXPECT_EQ(existsIp4, existsIp6);
258 return existsIp4;
259 }
260
// Checks that networkRejectNonSecureVpn() adds one "prohibit" rule per UID range at priority
// 12500 in both IPv4 and IPv6, and removes them again. It also checks that removing
// already-removed rules fails with ENOENT, and that the rule tables end as they started.
TEST_F(BinderTest, TestNetworkRejectNonSecureVpn) {
    constexpr uint32_t RULE_PRIORITY = 12500;

    constexpr int baseUid = MULTIUSER_APP_PER_USER_RANGE * 5;
    std::vector<UidRange> uidRanges = {
        {baseUid + 150, baseUid + 224},
        {baseUid + 226, baseUid + 300}
    };

    const std::vector<std::string> initialRulesV4 = listIpRules(IP_RULE_V4);
    const std::vector<std::string> initialRulesV6 = listIpRules(IP_RULE_V6);

    // Checks the rule counts relative to the initial snapshot, then whether the rule for
    // every range is (or is not) present.
    auto expectRules = [&](size_t extraRules, bool present) {
        EXPECT_EQ(initialRulesV4.size() + extraRules, listIpRules(IP_RULE_V4).size());
        EXPECT_EQ(initialRulesV6.size() + extraRules, listIpRules(IP_RULE_V6).size());
        for (const UidRange& range : uidRanges) {
            EXPECT_EQ(present, ipRuleExistsForRange(RULE_PRIORITY, range, "prohibit"));
        }
    };

    // Create two valid rules.
    ASSERT_TRUE(mNetd->networkRejectNonSecureVpn(true, uidRanges).isOk());
    expectRules(2, true);

    // Remove the rules.
    ASSERT_TRUE(mNetd->networkRejectNonSecureVpn(false, uidRanges).isOk());
    expectRules(0, false);

    // Fail to remove the rules a second time after they are already deleted.
    const binder::Status status = mNetd->networkRejectNonSecureVpn(false, uidRanges);
    ASSERT_EQ(binder::Status::EX_SERVICE_SPECIFIC, status.exceptionCode());
    EXPECT_EQ(ENOENT, status.serviceSpecificErrorCode());

    // All rules should be the same as before.
    EXPECT_EQ(initialRulesV4, listIpRules(IP_RULE_V4));
    EXPECT_EQ(initialRulesV6, listIpRules(IP_RULE_V6));
}
298
socketpair(int * clientSocket,int * serverSocket,int * acceptedSocket)299 void socketpair(int *clientSocket, int *serverSocket, int *acceptedSocket) {
300 *serverSocket = socket(AF_INET6, SOCK_STREAM, 0);
301 struct sockaddr_in6 server6 = { .sin6_family = AF_INET6 };
302 ASSERT_EQ(0, bind(*serverSocket, (struct sockaddr *) &server6, sizeof(server6)));
303
304 socklen_t addrlen = sizeof(server6);
305 ASSERT_EQ(0, getsockname(*serverSocket, (struct sockaddr *) &server6, &addrlen));
306 ASSERT_EQ(0, listen(*serverSocket, 10));
307
308 *clientSocket = socket(AF_INET6, SOCK_STREAM, 0);
309 struct sockaddr_in6 client6;
310 ASSERT_EQ(0, connect(*clientSocket, (struct sockaddr *) &server6, sizeof(server6)));
311 ASSERT_EQ(0, getsockname(*clientSocket, (struct sockaddr *) &client6, &addrlen));
312
313 *acceptedSocket = accept(*serverSocket, (struct sockaddr *) &server6, &addrlen);
314 ASSERT_NE(-1, *acceptedSocket);
315
316 ASSERT_EQ(0, memcmp(&client6, &server6, sizeof(client6)));
317 }
318
checkSocketpairOpen(int clientSocket,int acceptedSocket)319 void checkSocketpairOpen(int clientSocket, int acceptedSocket) {
320 char buf[4096];
321 EXPECT_EQ(4, write(clientSocket, "foo", sizeof("foo")));
322 EXPECT_EQ(4, read(acceptedSocket, buf, sizeof(buf)));
323 EXPECT_EQ(0, memcmp(buf, "foo", sizeof("foo")));
324 }
325
checkSocketpairClosed(int clientSocket,int acceptedSocket)326 void checkSocketpairClosed(int clientSocket, int acceptedSocket) {
327 // Check that the client socket was closed with ECONNABORTED.
328 int ret = write(clientSocket, "foo", sizeof("foo"));
329 int err = errno;
330 EXPECT_EQ(-1, ret);
331 EXPECT_EQ(ECONNABORTED, err);
332
333 // Check that it sent a RST to the server.
334 ret = write(acceptedSocket, "foo", sizeof("foo"));
335 err = errno;
336 EXPECT_EQ(-1, ret);
337 EXPECT_EQ(ECONNRESET, err);
338 }
339
// Exercises socketDestroy() against a socket owned by a chosen UID. The socket survives while
// its UID is outside every range, and again while it is in a range but on the skip list. Once
// it is in a range and no longer skipped, the socket is destroyed.
TEST_F(BinderTest, TestSocketDestroy) {
    int clientSocket, serverSocket, acceptedSocket;
    ASSERT_NO_FATAL_FAILURE(socketpair(&clientSocket, &serverSocket, &acceptedSocket));

    // Pick a random UID in the system UID range and make it the owner of the client socket.
    constexpr int baseUid = AID_APP - 2000;
    static_assert(baseUid > 0, "Not enough UIDs? Please fix this test.");
    const int uid = baseUid + 500 + arc4random_uniform(1000);
    EXPECT_EQ(0, fchown(clientSocket, uid, -1));

    // Ranges that surround uid without containing it, and a skip list without uid.
    std::vector<UidRange> uidRanges = {
        {baseUid + 42, baseUid + 449},
        {baseUid + 1536, AID_APP - 4},
        {baseUid + 498, uid - 1},
        {uid + 1, baseUid + 1520},
    };
    std::vector<int32_t> skipUids { baseUid + 123, baseUid + 1600 };

    // uid is not covered by any range, so the test socket must survive.
    EXPECT_TRUE(mNetd->socketDestroy(uidRanges, skipUids).isOk());
    checkSocketpairOpen(clientSocket, acceptedSocket);

    // Cover uid with a range, but also put it on the skip list: it must still survive.
    uidRanges = {
        {baseUid + 42, baseUid + 449},
        {baseUid + 1536, AID_APP - 4},
        {baseUid + 498, baseUid + 1520},
    };
    skipUids.push_back(uid);
    EXPECT_TRUE(mNetd->socketDestroy(uidRanges, skipUids).isOk());
    checkSocketpairOpen(clientSocket, acceptedSocket);

    // Drop uid from the skip list again: this time the test socket must be destroyed.
    skipUids.pop_back();
    EXPECT_TRUE(mNetd->socketDestroy(uidRanges, skipUids).isOk());
    checkSocketpairClosed(clientSocket, acceptedSocket);

    for (const int fd : {clientSocket, serverSocket, acceptedSocket}) {
        close(fd);
    }
}
386