1 /*
2  * Copyright (C) 2014 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "NetdClient.h"
18 
19 #include <arpa/inet.h>
20 #include <errno.h>
21 #include <math.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24 
25 #include <atomic>
26 
27 #include "Fwmark.h"
28 #include "FwmarkClient.h"
29 #include "FwmarkCommand.h"
30 #include "resolv_netid.h"
31 #include "Stopwatch.h"
32 
33 namespace {
34 
35 std::atomic_uint netIdForProcess(NETID_UNSET);
36 std::atomic_uint netIdForResolv(NETID_UNSET);
37 
// Signatures of the hooks that are swapped in by the netdClientInit*() functions.
using Accept4FunctionType = int (*)(int, sockaddr*, socklen_t*, int);
using ConnectFunctionType = int (*)(int, const sockaddr*, socklen_t);
using SocketFunctionType = int (*)(int, int, int);
using NetIdForResolvFunctionType = unsigned (*)(unsigned);
42 
43 // These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
44 // it's okay that they are read later at runtime without a lock.
45 Accept4FunctionType libcAccept4 = 0;
46 ConnectFunctionType libcConnect = 0;
47 SocketFunctionType libcSocket = 0;
48 
// Closes |fd| and reports a failure to the caller: errno is set to the negation of |error|
// (callers pass a negative errno value) and -1 is returned. errno is assigned after close(),
// so whatever close() leaves in errno is overwritten.
int closeFdAndSetErrno(int fd, int error) {
    const int errnoValue = -error;
    close(fd);
    errno = errnoValue;
    return -1;
}
54 
netdClientAccept4(int sockfd,sockaddr * addr,socklen_t * addrlen,int flags)55 int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
56     int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
57     if (acceptedSocket == -1) {
58         return -1;
59     }
60     int family;
61     if (addr) {
62         family = addr->sa_family;
63     } else {
64         socklen_t familyLen = sizeof(family);
65         if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
66             return closeFdAndSetErrno(acceptedSocket, -errno);
67         }
68     }
69     if (FwmarkClient::shouldSetFwmark(family)) {
70         FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0, 0};
71         if (int error = FwmarkClient().send(&command, acceptedSocket, nullptr)) {
72             return closeFdAndSetErrno(acceptedSocket, error);
73         }
74     }
75     return acceptedSocket;
76 }
77 
netdClientConnect(int sockfd,const sockaddr * addr,socklen_t addrlen)78 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
79     const bool shouldSetFwmark = (sockfd >= 0) && addr
80             && FwmarkClient::shouldSetFwmark(addr->sa_family);
81     if (shouldSetFwmark) {
82         FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0, 0};
83         if (int error = FwmarkClient().send(&command, sockfd, nullptr)) {
84             errno = -error;
85             return -1;
86         }
87     }
88     // Latency measurement does not include time of sending commands to Fwmark
89     Stopwatch s;
90     const int ret = libcConnect(sockfd, addr, addrlen);
91     // Save errno so it isn't clobbered by sending ON_CONNECT_COMPLETE
92     const int connectErrno = errno;
93     const unsigned latencyMs = lround(s.timeTaken());
94     // Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
95     if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
96         FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
97         // TODO: get the netId from the socket mark once we have continuous benchmark runs
98         FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
99                                  /* uid (filled in by the server) */ 0, 0};
100         // Ignore return value since it's only used for logging
101         FwmarkClient().send(&command, sockfd, &connectInfo);
102     }
103     errno = connectErrno;
104     return ret;
105 }
106 
netdClientSocket(int domain,int type,int protocol)107 int netdClientSocket(int domain, int type, int protocol) {
108     int socketFd = libcSocket(domain, type, protocol);
109     if (socketFd == -1) {
110         return -1;
111     }
112     unsigned netId = netIdForProcess;
113     if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
114         if (int error = setNetworkForSocket(netId, socketFd)) {
115             return closeFdAndSetErrno(socketFd, error);
116         }
117     }
118     return socketFd;
119 }
120 
getNetworkForResolv(unsigned netId)121 unsigned getNetworkForResolv(unsigned netId) {
122     if (netId != NETID_UNSET) {
123         return netId;
124     }
125     // Special case for DNS-over-TLS bypass; b/72345192 .
126     if ((netIdForResolv & ~NETID_USE_LOCAL_NAMESERVERS) != NETID_UNSET) {
127         return netIdForResolv;
128     }
129     netId = netIdForProcess;
130     if (netId != NETID_UNSET) {
131         return netId;
132     }
133     return netIdForResolv;
134 }
135 
setNetworkForTarget(unsigned netId,std::atomic_uint * target)136 int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
137     const unsigned requestedNetId = netId;
138     netId &= ~NETID_USE_LOCAL_NAMESERVERS;
139 
140     if (netId == NETID_UNSET) {
141         *target = netId;
142         return 0;
143     }
144     // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
145     // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
146     // might itself cause another check with the fwmark server, which would be wasteful.
147     int socketFd;
148     if (libcSocket) {
149         socketFd = libcSocket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
150     } else {
151         socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
152     }
153     if (socketFd < 0) {
154         return -errno;
155     }
156     int error = setNetworkForSocket(netId, socketFd);
157     if (!error) {
158         *target = (target == &netIdForResolv) ? requestedNetId : netId;
159     }
160     close(socketFd);
161     return error;
162 }
163 
164 }  // namespace
165 
166 // accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
netdClientInitAccept4(Accept4FunctionType * function)167 extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
168     if (function && *function) {
169         libcAccept4 = *function;
170         *function = netdClientAccept4;
171     }
172 }
173 
netdClientInitConnect(ConnectFunctionType * function)174 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
175     if (function && *function) {
176         libcConnect = *function;
177         *function = netdClientConnect;
178     }
179 }
180 
netdClientInitSocket(SocketFunctionType * function)181 extern "C" void netdClientInitSocket(SocketFunctionType* function) {
182     if (function && *function) {
183         libcSocket = *function;
184         *function = netdClientSocket;
185     }
186 }
187 
netdClientInitNetIdForResolv(NetIdForResolvFunctionType * function)188 extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
189     if (function) {
190         *function = getNetworkForResolv;
191     }
192 }
193 
getNetworkForSocket(unsigned * netId,int socketFd)194 extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
195     if (!netId || socketFd < 0) {
196         return -EBADF;
197     }
198     Fwmark fwmark;
199     socklen_t fwmarkLen = sizeof(fwmark.intValue);
200     if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
201         return -errno;
202     }
203     *netId = fwmark.netId;
204     return 0;
205 }
206 
getNetworkForProcess()207 extern "C" unsigned getNetworkForProcess() {
208     return netIdForProcess;
209 }
210 
setNetworkForSocket(unsigned netId,int socketFd)211 extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
212     if (socketFd < 0) {
213         return -EBADF;
214     }
215     FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0, 0};
216     return FwmarkClient().send(&command, socketFd, nullptr);
217 }
218 
setNetworkForProcess(unsigned netId)219 extern "C" int setNetworkForProcess(unsigned netId) {
220     return setNetworkForTarget(netId, &netIdForProcess);
221 }
222 
setNetworkForResolv(unsigned netId)223 extern "C" int setNetworkForResolv(unsigned netId) {
224     return setNetworkForTarget(netId, &netIdForResolv);
225 }
226 
protectFromVpn(int socketFd)227 extern "C" int protectFromVpn(int socketFd) {
228     if (socketFd < 0) {
229         return -EBADF;
230     }
231     FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0, 0};
232     return FwmarkClient().send(&command, socketFd, nullptr);
233 }
234 
setNetworkForUser(uid_t uid,int socketFd)235 extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
236     if (socketFd < 0) {
237         return -EBADF;
238     }
239     FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid, 0};
240     return FwmarkClient().send(&command, socketFd, nullptr);
241 }
242 
queryUserAccess(uid_t uid,unsigned netId)243 extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
244     FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid, 0};
245     return FwmarkClient().send(&command, -1, nullptr);
246 }
247 
tagSocket(int socketFd,uint32_t tag,uid_t uid)248 extern "C" int tagSocket(int socketFd, uint32_t tag, uid_t uid) {
249     FwmarkCommand command = {FwmarkCommand::TAG_SOCKET, 0, uid, tag};
250     return FwmarkClient().send(&command, socketFd, nullptr);
251 }
252 
untagSocket(int socketFd)253 extern "C" int untagSocket(int socketFd) {
254     FwmarkCommand command = {FwmarkCommand::UNTAG_SOCKET, 0, 0, 0};
255     return FwmarkClient().send(&command, socketFd, nullptr);
256 }
257 
setCounterSet(uint32_t counterSet,uid_t uid)258 extern "C" int setCounterSet(uint32_t counterSet, uid_t uid) {
259     FwmarkCommand command = {FwmarkCommand::SET_COUNTERSET, 0, uid, counterSet};
260     return FwmarkClient().send(&command, -1, nullptr);
261 }
262 
deleteTagData(uint32_t tag,uid_t uid)263 extern "C" int deleteTagData(uint32_t tag, uid_t uid) {
264     FwmarkCommand command = {FwmarkCommand::DELETE_TAGDATA, 0, uid, tag};
265     return FwmarkClient().send(&command, -1, nullptr);
266 }
267