1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  * binder_test.cpp - unit tests for netd binder RPCs.
17  */
18 
#include <cerrno>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <string>
#include <utility>
#include <vector>
26 
27 #include <sys/socket.h>
28 #include <netinet/in.h>
29 
30 #include <android-base/stringprintf.h>
31 #include <android-base/strings.h>
32 #include <cutils/multiuser.h>
33 #include <gtest/gtest.h>
34 #include <logwrap/logwrap.h>
35 
36 #include "NetdConstants.h"
37 #include "android/net/INetd.h"
38 #include "android/net/UidRange.h"
39 #include "binder/IServiceManager.h"
40 
41 using namespace android;
42 using namespace android::base;
43 using namespace android::binder;
44 using android::net::INetd;
45 using android::net::UidRange;
46 
// Address-family selectors passed to the ip tool (e.g. "<IP_PATH> -4 rule list").
// The pointers themselves are const so the selectors cannot be reassigned by accident.
static const char* const IP_RULE_V4 = "-4";
static const char* const IP_RULE_V6 = "-6";
49 
50 class BinderTest : public ::testing::Test {
51 
52 public:
BinderTest()53     BinderTest() {
54         sp<IServiceManager> sm = defaultServiceManager();
55         sp<IBinder> binder = sm->getService(String16("netd"));
56         if (binder != nullptr) {
57             mNetd = interface_cast<INetd>(binder);
58         }
59     }
60 
SetUp()61     void SetUp() {
62         ASSERT_NE(nullptr, mNetd.get());
63     }
64 
65 protected:
66     sp<INetd> mNetd;
67 };
68 
69 
70 class TimedOperation : public Stopwatch {
71 public:
TimedOperation(std::string name)72     TimedOperation(std::string name): mName(name) {}
~TimedOperation()73     virtual ~TimedOperation() {
74         fprintf(stderr, "    %s: %6.1f ms\n", mName.c_str(), timeTaken());
75     }
76 
77 private:
78     std::string mName;
79 };
80 
// Times a single isAlive() round trip and checks both that the binder transaction succeeded
// and that netd reported itself alive.
TEST_F(BinderTest, TestIsAlive) {
    TimedOperation t("isAlive RPC");
    bool isAlive = false;
    // Check the transaction status too: a failed RPC should be reported as such rather than
    // only surfacing as "isAlive is false".
    binder::Status status = mNetd->isAlive(&isAlive);
    ASSERT_TRUE(status.isOk());
    ASSERT_TRUE(isAlive);
}
87 
randomUid()88 static int randomUid() {
89     return 100000 * arc4random_uniform(7) + 10000 + arc4random_uniform(5000);
90 }
91 
// Runs |command| through the shell via popen(3) and returns its standard output split into
// lines. Each element keeps its trailing '\n' if the command printed one. If the command
// cannot be started, prints a perror() diagnostic and returns an empty vector.
static std::vector<std::string> runCommand(const std::string& command) {
    std::vector<std::string> lines;

    FILE* f = popen(command.c_str(), "r");
    if (f == nullptr) {
        perror("popen");
        return lines;
    }

    // getline(3) grows this buffer as needed. It is reused across iterations and freed exactly
    // once afterwards — including when getline() fails, since it may allocate even then.
    char* line = nullptr;
    size_t bufsize = 0;
    ssize_t linelen;
    while ((linelen = getline(&line, &bufsize, f)) >= 0) {
        // Explicit length keeps any embedded NUL bytes.
        lines.emplace_back(line, static_cast<size_t>(linelen));
    }
    free(line);

    pclose(f);
    return lines;
}
113 
listIpRules(const char * ipVersion)114 static std::vector<std::string> listIpRules(const char *ipVersion) {
115     std::string command = StringPrintf("%s %s rule list", IP_PATH, ipVersion);
116     return runCommand(command);
117 }
118 
listIptablesRule(const char * binary,const char * chainName)119 static std::vector<std::string> listIptablesRule(const char *binary, const char *chainName) {
120     std::string command = StringPrintf("%s -n -L %s", binary, chainName);
121     return runCommand(command);
122 }
123 
iptablesRuleLineLength(const char * binary,const char * chainName)124 static int iptablesRuleLineLength(const char *binary, const char *chainName) {
125     return listIptablesRule(binary, chainName).size();
126 }
127 
// Programs a randomly named UID chain as a whitelist and as a blacklist via
// firewallReplaceUidChain, checks the resulting iptables/ip6tables line counts after each
// step, and finally checks that an over-long chain name is rejected.
TEST_F(BinderTest, TestFirewallReplaceUidChain) {
    std::string chainName = StringPrintf("netd_binder_test_%u", arc4random_uniform(10000));
    const int kNumUids = 500;
    std::vector<int32_t> noUids(0);
    std::vector<int32_t> uids(kNumUids);
    for (int i = 0; i < kNumUids; i++) {
        uids[i] = randomUid();
    }

    // |ret| is reset before every RPC and each binder status is checked, so a failed
    // transaction can never leave the EXPECTs below reading an uninitialized or stale value.
    bool ret = false;
    {
        TimedOperation op(StringPrintf("Programming %d-UID whitelist chain", kNumUids));
        EXPECT_TRUE(mNetd->firewallReplaceUidChain(String16(chainName.c_str()), true, uids,
                                                   &ret).isOk());
    }
    EXPECT_EQ(true, ret);
    // Expected counts are one line per UID plus a fixed per-chain overhead (listing header
    // lines and any fixed rules netd installs; the overhead differs between IPv4 and IPv6).
    EXPECT_EQ((int) uids.size() + 5, iptablesRuleLineLength(IPTABLES_PATH, chainName.c_str()));
    EXPECT_EQ((int) uids.size() + 11, iptablesRuleLineLength(IP6TABLES_PATH, chainName.c_str()));

    ret = false;
    {
        TimedOperation op("Clearing whitelist chain");
        // NOTE(review): this passes `false` for the whitelist flag (same as the blacklist
        // calls below); the expected count of 3 matches the empty-blacklist case.
        EXPECT_TRUE(mNetd->firewallReplaceUidChain(String16(chainName.c_str()), false, noUids,
                                                   &ret).isOk());
    }
    EXPECT_EQ(true, ret);
    EXPECT_EQ(3, iptablesRuleLineLength(IPTABLES_PATH, chainName.c_str()));
    EXPECT_EQ(3, iptablesRuleLineLength(IP6TABLES_PATH, chainName.c_str()));

    ret = false;
    {
        TimedOperation op(StringPrintf("Programming %d-UID blacklist chain", kNumUids));
        EXPECT_TRUE(mNetd->firewallReplaceUidChain(String16(chainName.c_str()), false, uids,
                                                   &ret).isOk());
    }
    EXPECT_EQ(true, ret);
    EXPECT_EQ((int) uids.size() + 3, iptablesRuleLineLength(IPTABLES_PATH, chainName.c_str()));
    EXPECT_EQ((int) uids.size() + 3, iptablesRuleLineLength(IP6TABLES_PATH, chainName.c_str()));

    ret = false;
    {
        TimedOperation op("Clearing blacklist chain");
        EXPECT_TRUE(mNetd->firewallReplaceUidChain(String16(chainName.c_str()), false, noUids,
                                                   &ret).isOk());
    }
    EXPECT_EQ(true, ret);
    EXPECT_EQ(3, iptablesRuleLineLength(IPTABLES_PATH, chainName.c_str()));
    EXPECT_EQ(3, iptablesRuleLineLength(IP6TABLES_PATH, chainName.c_str()));

    // Check that the call fails if iptables returns an error. |ret| is preset to true so that a
    // false result cannot simply be the value left over from initialization.
    std::string veryLongStringName = "netd_binder_test_UnacceptablyLongIptablesChainName";
    ret = true;
    mNetd->firewallReplaceUidChain(String16(veryLongStringName.c_str()), true, noUids, &ret);
    EXPECT_EQ(false, ret);
}
174 
bandwidthDataSaverEnabled(const char * binary)175 static int bandwidthDataSaverEnabled(const char *binary) {
176     std::vector<std::string> lines = listIptablesRule(binary, "bw_data_saver");
177 
178     // Output looks like this:
179     //
180     // Chain bw_data_saver (1 references)
181     // target     prot opt source               destination
182     // RETURN     all  --  0.0.0.0/0            0.0.0.0/0
183     EXPECT_EQ(3U, lines.size());
184     if (lines.size() != 3) return -1;
185 
186     EXPECT_TRUE(android::base::StartsWith(lines[2], "RETURN ") ||
187                 android::base::StartsWith(lines[2], "REJECT "));
188 
189     return android::base::StartsWith(lines[2], "REJECT");
190 }
191 
enableDataSaver(sp<INetd> & netd,bool enable)192 bool enableDataSaver(sp<INetd>& netd, bool enable) {
193     TimedOperation op(enable ? " Enabling data saver" : "Disabling data saver");
194     bool ret;
195     netd->bandwidthEnableDataSaver(enable, &ret);
196     return ret;
197 }
198 
getDataSaverState()199 int getDataSaverState() {
200     const int enabled4 = bandwidthDataSaverEnabled(IPTABLES_PATH);
201     const int enabled6 = bandwidthDataSaverEnabled(IP6TABLES_PATH);
202     EXPECT_EQ(enabled4, enabled6);
203     EXPECT_NE(-1, enabled4);
204     EXPECT_NE(-1, enabled6);
205     if (enabled4 != enabled6 || (enabled6 != 0 && enabled6 != 1)) {
206         return -1;
207     }
208     return enabled6;
209 }
210 
// Disables and enables data saver twice each, verifying the iptables state after every call,
// then leaves data saver in the state it was in before the test.
TEST_F(BinderTest, TestBandwidthEnableDataSaver) {
    const int initialState = getDataSaverState();
    ASSERT_NE(-1, initialState);

    // Start from a known "disabled" state.
    if (initialState != 0) {
        ASSERT_TRUE(enableDataSaver(mNetd, false));
        EXPECT_EQ(0, getDataSaverState());
    }

    // Disabling an already-disabled data saver must still succeed.
    ASSERT_TRUE(enableDataSaver(mNetd, false));
    EXPECT_EQ(0, getDataSaverState());

    // Enable, then enable again.
    ASSERT_TRUE(enableDataSaver(mNetd, true));
    EXPECT_EQ(1, getDataSaverState());
    ASSERT_TRUE(enableDataSaver(mNetd, true));
    EXPECT_EQ(1, getDataSaverState());

    // Data saver is now on; switch it back off only if that is how we found it.
    if (initialState == 0) {
        ASSERT_TRUE(enableDataSaver(mNetd, false));
        EXPECT_EQ(0, getDataSaverState());
    }
}
234 
ipRuleExistsForRange(const uint32_t priority,const UidRange & range,const std::string & action,const char * ipVersion)235 static bool ipRuleExistsForRange(const uint32_t priority, const UidRange& range,
236         const std::string& action, const char* ipVersion) {
237     // Output looks like this:
238     //   "12500:\tfrom all fwmark 0x0/0x20000 iif lo uidrange 1000-2000 prohibit"
239     std::vector<std::string> rules = listIpRules(ipVersion);
240 
241     std::string prefix = StringPrintf("%" PRIu32 ":", priority);
242     std::string suffix = StringPrintf(" iif lo uidrange %d-%d %s\n",
243             range.getStart(), range.getStop(), action.c_str());
244     for (std::string line : rules) {
245         if (android::base::StartsWith(line, prefix.c_str())
246                 && android::base::EndsWith(line, suffix.c_str())) {
247             return true;
248         }
249     }
250     return false;
251 }
252 
ipRuleExistsForRange(const uint32_t priority,const UidRange & range,const std::string & action)253 static bool ipRuleExistsForRange(const uint32_t priority, const UidRange& range,
254         const std::string& action) {
255     bool existsIp4 = ipRuleExistsForRange(priority, range, action, IP_RULE_V4);
256     bool existsIp6 = ipRuleExistsForRange(priority, range, action, IP_RULE_V6);
257     EXPECT_EQ(existsIp4, existsIp6);
258     return existsIp4;
259 }
260 
// Adds and then removes "prohibit" ip rules for two UID ranges via networkRejectNonSecureVpn,
// checking both rule tables after each step. A second removal must fail with a service-specific
// ENOENT and leave the rule tables exactly as they were before the test.
TEST_F(BinderTest, TestNetworkRejectNonSecureVpn) {
    constexpr uint32_t kRulePriority = 12500;

    // Ranges offset from MULTIUSER_APP_PER_USER_RANGE * 5 (presumably user 5's UID block).
    constexpr int userBase = MULTIUSER_APP_PER_USER_RANGE * 5;
    std::vector<UidRange> ranges = {
        {userBase + 150, userBase + 224},
        {userBase + 226, userBase + 300}
    };

    const std::vector<std::string> rulesBeforeV4 = listIpRules(IP_RULE_V4);
    const std::vector<std::string> rulesBeforeV6 = listIpRules(IP_RULE_V6);

    // Adding the two ranges should add two rules to each table.
    ASSERT_TRUE(mNetd->networkRejectNonSecureVpn(true, ranges).isOk());
    EXPECT_EQ(rulesBeforeV4.size() + 2, listIpRules(IP_RULE_V4).size());
    EXPECT_EQ(rulesBeforeV6.size() + 2, listIpRules(IP_RULE_V6).size());
    for (const UidRange& r : ranges) {
        EXPECT_TRUE(ipRuleExistsForRange(kRulePriority, r, "prohibit"));
    }

    // Removing them should bring the table sizes back and drop both rules.
    ASSERT_TRUE(mNetd->networkRejectNonSecureVpn(false, ranges).isOk());
    EXPECT_EQ(rulesBeforeV4.size(), listIpRules(IP_RULE_V4).size());
    EXPECT_EQ(rulesBeforeV6.size(), listIpRules(IP_RULE_V6).size());
    for (const UidRange& r : ranges) {
        EXPECT_FALSE(ipRuleExistsForRange(kRulePriority, r, "prohibit"));
    }

    // A second removal must fail because the rules are already gone.
    binder::Status secondRemoval = mNetd->networkRejectNonSecureVpn(false, ranges);
    ASSERT_EQ(binder::Status::EX_SERVICE_SPECIFIC, secondRemoval.exceptionCode());
    EXPECT_EQ(ENOENT, secondRemoval.serviceSpecificErrorCode());

    // The rule tables must be identical to what they were before the test.
    EXPECT_EQ(rulesBeforeV4, listIpRules(IP_RULE_V4));
    EXPECT_EQ(rulesBeforeV6, listIpRules(IP_RULE_V6));
}
298 
socketpair(int * clientSocket,int * serverSocket,int * acceptedSocket)299 void socketpair(int *clientSocket, int *serverSocket, int *acceptedSocket) {
300     *serverSocket = socket(AF_INET6, SOCK_STREAM, 0);
301     struct sockaddr_in6 server6 = { .sin6_family = AF_INET6 };
302     ASSERT_EQ(0, bind(*serverSocket, (struct sockaddr *) &server6, sizeof(server6)));
303 
304     socklen_t addrlen = sizeof(server6);
305     ASSERT_EQ(0, getsockname(*serverSocket, (struct sockaddr *) &server6, &addrlen));
306     ASSERT_EQ(0, listen(*serverSocket, 10));
307 
308     *clientSocket = socket(AF_INET6, SOCK_STREAM, 0);
309     struct sockaddr_in6 client6;
310     ASSERT_EQ(0, connect(*clientSocket, (struct sockaddr *) &server6, sizeof(server6)));
311     ASSERT_EQ(0, getsockname(*clientSocket, (struct sockaddr *) &client6, &addrlen));
312 
313     *acceptedSocket = accept(*serverSocket, (struct sockaddr *) &server6, &addrlen);
314     ASSERT_NE(-1, *acceptedSocket);
315 
316     ASSERT_EQ(0, memcmp(&client6, &server6, sizeof(client6)));
317 }
318 
checkSocketpairOpen(int clientSocket,int acceptedSocket)319 void checkSocketpairOpen(int clientSocket, int acceptedSocket) {
320     char buf[4096];
321     EXPECT_EQ(4, write(clientSocket, "foo", sizeof("foo")));
322     EXPECT_EQ(4, read(acceptedSocket, buf, sizeof(buf)));
323     EXPECT_EQ(0, memcmp(buf, "foo", sizeof("foo")));
324 }
325 
checkSocketpairClosed(int clientSocket,int acceptedSocket)326 void checkSocketpairClosed(int clientSocket, int acceptedSocket) {
327     // Check that the client socket was closed with ECONNABORTED.
328     int ret = write(clientSocket, "foo", sizeof("foo"));
329     int err = errno;
330     EXPECT_EQ(-1, ret);
331     EXPECT_EQ(ECONNABORTED, err);
332 
333     // Check that it sent a RST to the server.
334     ret = write(acceptedSocket, "foo", sizeof("foo"));
335     err = errno;
336     EXPECT_EQ(-1, ret);
337     EXPECT_EQ(ECONNRESET, err);
338 }
339 
// Chowns one end of a fresh TCP pair to a random system UID and exercises socketDestroy:
// the socket must survive when its UID is outside every range or is on the skip list, and must
// be destroyed once its UID is in range and no longer skipped.
TEST_F(BinderTest, TestSocketDestroy) {
    int clientSocket, serverSocket, acceptedSocket;
    ASSERT_NO_FATAL_FAILURE(socketpair(&clientSocket, &serverSocket, &acceptedSocket));

    // Pick a random UID in the system UID range.
    constexpr int baseUid = AID_APP - 2000;
    static_assert(baseUid > 0, "Not enough UIDs? Please fix this test.");
    const int testUid = baseUid + 500 + arc4random_uniform(1000);
    EXPECT_EQ(0, fchown(clientSocket, testUid, -1));

    // Ranges that all exclude testUid, plus a skip list that also excludes it.
    std::vector<UidRange> ranges = {
        {baseUid + 42, baseUid + 449},
        {baseUid + 1536, AID_APP - 4},
        {baseUid + 498, testUid - 1},
        {testUid + 1, baseUid + 1520},
    };
    std::vector<int32_t> skipList { baseUid + 123, baseUid + 1600 };

    // Nothing matches our socket, so it must stay open.
    EXPECT_TRUE(mNetd->socketDestroy(ranges, skipList).isOk());
    checkSocketpairOpen(clientSocket, acceptedSocket);

    // Now the ranges cover testUid, but testUid is also on the skip list.
    ranges = {
        {baseUid + 42, baseUid + 449},
        {baseUid + 1536, AID_APP - 4},
        {baseUid + 498, baseUid + 1520},
    };
    skipList.push_back(testUid);

    // Skipped, so the socket must still be open.
    EXPECT_TRUE(mNetd->socketDestroy(ranges, skipList).isOk());
    checkSocketpairOpen(clientSocket, acceptedSocket);

    // Take testUid back off the skip list; this time the socket must be destroyed.
    skipList.pop_back();
    EXPECT_TRUE(mNetd->socketDestroy(ranges, skipList).isOk());
    checkSocketpairClosed(clientSocket, acceptedSocket);

    close(clientSocket);
    close(serverSocket);
    close(acceptedSocket);
}
386