1 /*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "NetdClient.h"
18
19 #include <arpa/inet.h>
20 #include <errno.h>
21 #include <math.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24
25 #include <atomic>
26
27 #include "Fwmark.h"
28 #include "FwmarkClient.h"
29 #include "FwmarkCommand.h"
30 #include "resolv_netid.h"
31 #include "Stopwatch.h"
32
33 namespace {
34
35 std::atomic_uint netIdForProcess(NETID_UNSET);
36 std::atomic_uint netIdForResolv(NETID_UNSET);
37
// Signatures of the libc entry points this file interposes on, plus the signature of the
// resolver netId hook installed by netdClientInitNetIdForResolv().
using Accept4FunctionType = int (*)(int, sockaddr*, socklen_t*, int);
using ConnectFunctionType = int (*)(int, const sockaddr*, socklen_t);
using SocketFunctionType = int (*)(int, int, int);
using NetIdForResolvFunctionType = unsigned (*)(unsigned);

// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
// it's okay that they are read later at runtime without a lock. They stay null until the
// matching netdClientInit*() function has stored the original libc implementation in them.
Accept4FunctionType libcAccept4 = nullptr;
ConnectFunctionType libcConnect = nullptr;
SocketFunctionType libcSocket = nullptr;
48
// Closes |fd|, then sets errno to the negation of |error| and returns -1.
// errno is assigned only after close() has run, so whatever close() does to errno is overwritten.
int closeFdAndSetErrno(int fd, int error) {
    close(fd);
    const int negatedError = -error;
    errno = negatedError;
    return -1;
}
54
netdClientAccept4(int sockfd,sockaddr * addr,socklen_t * addrlen,int flags)55 int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
56 int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
57 if (acceptedSocket == -1) {
58 return -1;
59 }
60 int family;
61 if (addr) {
62 family = addr->sa_family;
63 } else {
64 socklen_t familyLen = sizeof(family);
65 if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
66 return closeFdAndSetErrno(acceptedSocket, -errno);
67 }
68 }
69 if (FwmarkClient::shouldSetFwmark(family)) {
70 FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0, 0};
71 if (int error = FwmarkClient().send(&command, acceptedSocket, nullptr)) {
72 return closeFdAndSetErrno(acceptedSocket, error);
73 }
74 }
75 return acceptedSocket;
76 }
77
netdClientConnect(int sockfd,const sockaddr * addr,socklen_t addrlen)78 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
79 const bool shouldSetFwmark = (sockfd >= 0) && addr
80 && FwmarkClient::shouldSetFwmark(addr->sa_family);
81 if (shouldSetFwmark) {
82 FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0, 0};
83 if (int error = FwmarkClient().send(&command, sockfd, nullptr)) {
84 errno = -error;
85 return -1;
86 }
87 }
88 // Latency measurement does not include time of sending commands to Fwmark
89 Stopwatch s;
90 const int ret = libcConnect(sockfd, addr, addrlen);
91 // Save errno so it isn't clobbered by sending ON_CONNECT_COMPLETE
92 const int connectErrno = errno;
93 const unsigned latencyMs = lround(s.timeTaken());
94 // Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
95 if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
96 FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
97 // TODO: get the netId from the socket mark once we have continuous benchmark runs
98 FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
99 /* uid (filled in by the server) */ 0, 0};
100 // Ignore return value since it's only used for logging
101 FwmarkClient().send(&command, sockfd, &connectInfo);
102 }
103 errno = connectErrno;
104 return ret;
105 }
106
netdClientSocket(int domain,int type,int protocol)107 int netdClientSocket(int domain, int type, int protocol) {
108 int socketFd = libcSocket(domain, type, protocol);
109 if (socketFd == -1) {
110 return -1;
111 }
112 unsigned netId = netIdForProcess;
113 if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
114 if (int error = setNetworkForSocket(netId, socketFd)) {
115 return closeFdAndSetErrno(socketFd, error);
116 }
117 }
118 return socketFd;
119 }
120
getNetworkForResolv(unsigned netId)121 unsigned getNetworkForResolv(unsigned netId) {
122 if (netId != NETID_UNSET) {
123 return netId;
124 }
125 // Special case for DNS-over-TLS bypass; b/72345192 .
126 if ((netIdForResolv & ~NETID_USE_LOCAL_NAMESERVERS) != NETID_UNSET) {
127 return netIdForResolv;
128 }
129 netId = netIdForProcess;
130 if (netId != NETID_UNSET) {
131 return netId;
132 }
133 return netIdForResolv;
134 }
135
setNetworkForTarget(unsigned netId,std::atomic_uint * target)136 int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
137 const unsigned requestedNetId = netId;
138 netId &= ~NETID_USE_LOCAL_NAMESERVERS;
139
140 if (netId == NETID_UNSET) {
141 *target = netId;
142 return 0;
143 }
144 // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
145 // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
146 // might itself cause another check with the fwmark server, which would be wasteful.
147 int socketFd;
148 if (libcSocket) {
149 socketFd = libcSocket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
150 } else {
151 socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
152 }
153 if (socketFd < 0) {
154 return -errno;
155 }
156 int error = setNetworkForSocket(netId, socketFd);
157 if (!error) {
158 *target = (target == &netIdForResolv) ? requestedNetId : netId;
159 }
160 close(socketFd);
161 return error;
162 }
163
164 } // namespace
165
166 // accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
netdClientInitAccept4(Accept4FunctionType * function)167 extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
168 if (function && *function) {
169 libcAccept4 = *function;
170 *function = netdClientAccept4;
171 }
172 }
173
netdClientInitConnect(ConnectFunctionType * function)174 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
175 if (function && *function) {
176 libcConnect = *function;
177 *function = netdClientConnect;
178 }
179 }
180
netdClientInitSocket(SocketFunctionType * function)181 extern "C" void netdClientInitSocket(SocketFunctionType* function) {
182 if (function && *function) {
183 libcSocket = *function;
184 *function = netdClientSocket;
185 }
186 }
187
netdClientInitNetIdForResolv(NetIdForResolvFunctionType * function)188 extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
189 if (function) {
190 *function = getNetworkForResolv;
191 }
192 }
193
getNetworkForSocket(unsigned * netId,int socketFd)194 extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
195 if (!netId || socketFd < 0) {
196 return -EBADF;
197 }
198 Fwmark fwmark;
199 socklen_t fwmarkLen = sizeof(fwmark.intValue);
200 if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
201 return -errno;
202 }
203 *netId = fwmark.netId;
204 return 0;
205 }
206
getNetworkForProcess()207 extern "C" unsigned getNetworkForProcess() {
208 return netIdForProcess;
209 }
210
setNetworkForSocket(unsigned netId,int socketFd)211 extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
212 if (socketFd < 0) {
213 return -EBADF;
214 }
215 FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0, 0};
216 return FwmarkClient().send(&command, socketFd, nullptr);
217 }
218
setNetworkForProcess(unsigned netId)219 extern "C" int setNetworkForProcess(unsigned netId) {
220 return setNetworkForTarget(netId, &netIdForProcess);
221 }
222
setNetworkForResolv(unsigned netId)223 extern "C" int setNetworkForResolv(unsigned netId) {
224 return setNetworkForTarget(netId, &netIdForResolv);
225 }
226
protectFromVpn(int socketFd)227 extern "C" int protectFromVpn(int socketFd) {
228 if (socketFd < 0) {
229 return -EBADF;
230 }
231 FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0, 0};
232 return FwmarkClient().send(&command, socketFd, nullptr);
233 }
234
setNetworkForUser(uid_t uid,int socketFd)235 extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
236 if (socketFd < 0) {
237 return -EBADF;
238 }
239 FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid, 0};
240 return FwmarkClient().send(&command, socketFd, nullptr);
241 }
242
queryUserAccess(uid_t uid,unsigned netId)243 extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
244 FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid, 0};
245 return FwmarkClient().send(&command, -1, nullptr);
246 }
247
tagSocket(int socketFd,uint32_t tag,uid_t uid)248 extern "C" int tagSocket(int socketFd, uint32_t tag, uid_t uid) {
249 FwmarkCommand command = {FwmarkCommand::TAG_SOCKET, 0, uid, tag};
250 return FwmarkClient().send(&command, socketFd, nullptr);
251 }
252
untagSocket(int socketFd)253 extern "C" int untagSocket(int socketFd) {
254 FwmarkCommand command = {FwmarkCommand::UNTAG_SOCKET, 0, 0, 0};
255 return FwmarkClient().send(&command, socketFd, nullptr);
256 }
257
setCounterSet(uint32_t counterSet,uid_t uid)258 extern "C" int setCounterSet(uint32_t counterSet, uid_t uid) {
259 FwmarkCommand command = {FwmarkCommand::SET_COUNTERSET, 0, uid, counterSet};
260 return FwmarkClient().send(&command, -1, nullptr);
261 }
262
deleteTagData(uint32_t tag,uid_t uid)263 extern "C" int deleteTagData(uint32_t tag, uid_t uid) {
264 FwmarkCommand command = {FwmarkCommand::DELETE_TAGDATA, 0, uid, tag};
265 return FwmarkClient().send(&command, -1, nullptr);
266 }
267