1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.net.module.util.netlink;
18 
19 import static android.os.Process.INVALID_UID;
20 import static android.system.OsConstants.AF_INET;
21 import static android.system.OsConstants.AF_INET6;
22 import static android.system.OsConstants.ENOENT;
23 import static android.system.OsConstants.IPPROTO_TCP;
24 import static android.system.OsConstants.IPPROTO_UDP;
25 import static android.system.OsConstants.NETLINK_INET_DIAG;
26 
27 import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DESTROY;
28 import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
29 import static com.android.net.module.util.netlink.NetlinkConstants.SOCKDIAG_MSG_HEADER_SIZE;
30 import static com.android.net.module.util.netlink.NetlinkConstants.stringForAddressFamily;
31 import static com.android.net.module.util.netlink.NetlinkConstants.stringForProtocol;
32 import static com.android.net.module.util.netlink.NetlinkUtils.DEFAULT_RECV_BUFSIZE;
33 import static com.android.net.module.util.netlink.NetlinkUtils.IO_TIMEOUT_MS;
34 import static com.android.net.module.util.netlink.NetlinkUtils.SOCKET_RECV_BUFSIZE;
35 import static com.android.net.module.util.netlink.NetlinkUtils.TCP_ALIVE_STATE_FILTER;
36 import static com.android.net.module.util.netlink.NetlinkUtils.connectToKernel;
37 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
38 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
39 
40 import android.net.util.SocketUtils;
41 import android.os.Process;
42 import android.os.SystemClock;
43 import android.system.ErrnoException;
44 import android.util.Log;
45 import android.util.Range;
46 
47 import androidx.annotation.NonNull;
48 import androidx.annotation.Nullable;
49 import androidx.annotation.VisibleForTesting;
50 
51 import java.io.FileDescriptor;
52 import java.io.IOException;
53 import java.io.InterruptedIOException;
54 import java.net.Inet4Address;
55 import java.net.Inet6Address;
56 import java.net.InetAddress;
57 import java.net.InetSocketAddress;
58 import java.net.SocketException;
59 import java.net.UnknownHostException;
60 import java.nio.ByteBuffer;
61 import java.nio.ByteOrder;
62 import java.util.ArrayList;
63 import java.util.List;
64 import java.util.Set;
65 import java.util.concurrent.atomic.AtomicInteger;
66 import java.util.function.Consumer;
67 import java.util.function.Predicate;
68 
69 /**
70  * A NetlinkMessage subclass for netlink inet_diag messages.
71  *
72  * see also: <linux_src>/include/uapi/linux/inet_diag.h
73  *
74  * @hide
75  */
76 public class InetDiagMessage extends NetlinkMessage {
77     public static final String TAG = "InetDiagMessage";
78     private static final int TIMEOUT_MS = 500;
79 
80     /**
81      * Construct an inet_diag_req_v2 message. This method will throw
82      * {@link IllegalArgumentException} if local and remote are not both null or both non-null.
83      */
inetDiagReqV2(int protocol, InetSocketAddress local, InetSocketAddress remote, int family, short flags)84     public static byte[] inetDiagReqV2(int protocol, InetSocketAddress local,
85             InetSocketAddress remote, int family, short flags) {
86         return inetDiagReqV2(protocol, local, remote, family, flags, 0 /* pad */,
87                 0 /* idiagExt */, StructInetDiagReqV2.INET_DIAG_REQ_V2_ALL_STATES);
88     }
89 
90     /**
91      * Construct an inet_diag_req_v2 message. This method will throw
92      * {@code IllegalArgumentException} if local and remote are not both null or both non-null.
93      *
94      * @param protocol the request protocol type. This should be set to one of IPPROTO_TCP,
95      *                 IPPROTO_UDP, or IPPROTO_UDPLITE.
96      * @param local local socket address of the target socket. This will be packed into a
97      *              {@link StructInetDiagSockId}. Request to diagnose for all sockets if both of
98      *              local or remote address is null.
99      * @param remote remote socket address of the target socket. This will be packed into a
100      *              {@link StructInetDiagSockId}. Request to diagnose for all sockets if both of
101      *              local or remote address is null.
102      * @param family the ip family of the request message. This should be set to either AF_INET or
103      *               AF_INET6 for IPv4 or IPv6 sockets respectively.
104      * @param flags message flags. See <linux_src>/include/uapi/linux/netlink.h.
105      * @param pad for raw socket protocol specification.
106      * @param idiagExt a set of flags defining what kind of extended information to report.
107      * @param state a bit mask that defines a filter of socket states.
108      *
109      * @return bytes array representation of the message
110      */
inetDiagReqV2(int protocol, @Nullable InetSocketAddress local, @Nullable InetSocketAddress remote, int family, short flags, int pad, int idiagExt, int state)111     public static byte[] inetDiagReqV2(int protocol, @Nullable InetSocketAddress local,
112             @Nullable InetSocketAddress remote, int family, short flags, int pad, int idiagExt,
113             int state) throws IllegalArgumentException {
114         // Request for all sockets if no specific socket is requested. Specify the local and remote
115         // socket address information for target request socket.
116         if ((local == null) != (remote == null)) {
117             throw new IllegalArgumentException(
118                     "Local and remote must be both null or both non-null");
119         }
120         final StructInetDiagSockId id = ((local != null && remote != null)
121                 ? new StructInetDiagSockId(local, remote) : null);
122         return inetDiagReqV2(protocol, id, family,
123                 SOCK_DIAG_BY_FAMILY, flags, pad, idiagExt, state);
124     }
125 
126     /**
127      * Construct an inet_diag_req_v2 message.
128      *
129      * @param protocol the request protocol type. This should be set to one of IPPROTO_TCP,
130      *                 IPPROTO_UDP, or IPPROTO_UDPLITE.
131      * @param id inet_diag_sockid. See {@link StructInetDiagSockId}
132      * @param family the ip family of the request message. This should be set to either AF_INET or
133      *               AF_INET6 for IPv4 or IPv6 sockets respectively.
134      * @param type message types.
135      * @param flags message flags. See <linux_src>/include/uapi/linux/netlink.h.
136      * @param pad for raw socket protocol specification.
137      * @param idiagExt a set of flags defining what kind of extended information to report.
138      * @param state a bit mask that defines a filter of socket states.
139      * @return bytes array representation of the message
140      */
inetDiagReqV2(int protocol, @Nullable StructInetDiagSockId id, int family, short type, short flags, int pad, int idiagExt, int state)141     public static byte[] inetDiagReqV2(int protocol, @Nullable StructInetDiagSockId id, int family,
142             short type, short flags, int pad, int idiagExt, int state) {
143         final byte[] bytes = new byte[StructNlMsgHdr.STRUCT_SIZE + StructInetDiagReqV2.STRUCT_SIZE];
144         final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
145         byteBuffer.order(ByteOrder.nativeOrder());
146 
147         final StructNlMsgHdr nlMsgHdr = new StructNlMsgHdr();
148         nlMsgHdr.nlmsg_len = bytes.length;
149         nlMsgHdr.nlmsg_type = type;
150         nlMsgHdr.nlmsg_flags = flags;
151         nlMsgHdr.pack(byteBuffer);
152         final StructInetDiagReqV2 inetDiagReqV2 =
153                 new StructInetDiagReqV2(protocol, id, family, pad, idiagExt, state);
154 
155         inetDiagReqV2.pack(byteBuffer);
156         return bytes;
157     }
158 
159     public StructInetDiagMsg inetDiagMsg;
160     // The netlink attributes.
161     public List<StructNlAttr> nlAttrs = new ArrayList<>();
162     @VisibleForTesting
InetDiagMessage(@onNull StructNlMsgHdr header)163     public InetDiagMessage(@NonNull StructNlMsgHdr header) {
164         super(header);
165         inetDiagMsg = new StructInetDiagMsg();
166     }
167 
168     /**
169      * Parse an inet_diag_req_v2 message from buffer.
170      */
171     @Nullable
parse(@onNull StructNlMsgHdr header, @NonNull ByteBuffer byteBuffer)172     public static InetDiagMessage parse(@NonNull StructNlMsgHdr header,
173             @NonNull ByteBuffer byteBuffer) {
174         final InetDiagMessage msg = new InetDiagMessage(header);
175         msg.inetDiagMsg = StructInetDiagMsg.parse(byteBuffer);
176         if (msg.inetDiagMsg == null) {
177             return null;
178         }
179         final int payloadLength = header.nlmsg_len - SOCKDIAG_MSG_HEADER_SIZE;
180         final ByteBuffer payload = byteBuffer.slice();
181         while (payload.position() < payloadLength) {
182             final StructNlAttr attr = StructNlAttr.parse(payload);
183             // Stop parsing for truncated or malformed attribute
184             if (attr == null)  {
185                 Log.wtf(TAG, "Got truncated or malformed attribute");
186                 return null;
187             }
188 
189             msg.nlAttrs.add(attr);
190         }
191 
192         return msg;
193     }
194 
closeSocketQuietly(final FileDescriptor fd)195     private static void closeSocketQuietly(final FileDescriptor fd) {
196         try {
197             SocketUtils.closeSocket(fd);
198         } catch (IOException ignored) {
199         }
200     }
201 
lookupUidByFamily(int protocol, InetSocketAddress local, InetSocketAddress remote, int family, short flags, FileDescriptor fd)202     private static int lookupUidByFamily(int protocol, InetSocketAddress local,
203                                          InetSocketAddress remote, int family, short flags,
204                                          FileDescriptor fd)
205             throws ErrnoException, InterruptedIOException {
206         byte[] msg = inetDiagReqV2(protocol, local, remote, family, flags);
207         NetlinkUtils.sendMessage(fd, msg, 0, msg.length, TIMEOUT_MS);
208         ByteBuffer response = NetlinkUtils.recvMessage(fd, DEFAULT_RECV_BUFSIZE, TIMEOUT_MS);
209 
210         final NetlinkMessage nlMsg = NetlinkMessage.parse(response, NETLINK_INET_DIAG);
211         if (nlMsg == null) {
212             return INVALID_UID;
213         }
214         final StructNlMsgHdr hdr = nlMsg.getHeader();
215         if (hdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) {
216             return INVALID_UID;
217         }
218         if (nlMsg instanceof InetDiagMessage) {
219             return ((InetDiagMessage) nlMsg).inetDiagMsg.idiag_uid;
220         }
221         return INVALID_UID;
222     }
223 
224     private static final int[] FAMILY = {AF_INET6, AF_INET};
225 
lookupUid(int protocol, InetSocketAddress local, InetSocketAddress remote, FileDescriptor fd)226     private static int lookupUid(int protocol, InetSocketAddress local,
227                                  InetSocketAddress remote, FileDescriptor fd)
228             throws ErrnoException, InterruptedIOException {
229         int uid;
230 
231         for (int family : FAMILY) {
232             /**
233              * For exact match lookup, swap local and remote for UDP lookups due to kernel
234              * bug which will not be fixed. See aosp/755889 and
235              * https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
236              */
237             if (protocol == IPPROTO_UDP) {
238                 uid = lookupUidByFamily(protocol, remote, local, family, NLM_F_REQUEST, fd);
239             } else {
240                 uid = lookupUidByFamily(protocol, local, remote, family, NLM_F_REQUEST, fd);
241             }
242             if (uid != INVALID_UID) {
243                 return uid;
244             }
245         }
246 
247         /**
248          * For UDP it's possible for a socket to send packets to arbitrary destinations, even if the
249          * socket is not connected (and even if the socket is connected to a different destination).
250          * If we want this API to work for such packets, then on miss we need to do a second lookup
251          * with only the local address and port filled in.
252          * Always use flags == NLM_F_REQUEST | NLM_F_DUMP for wildcard.
253          */
254         if (protocol == IPPROTO_UDP) {
255             try {
256                 InetSocketAddress wildcard = new InetSocketAddress(
257                         Inet6Address.getByName("::"), 0);
258                 uid = lookupUidByFamily(protocol, local, wildcard, AF_INET6,
259                         (short) (NLM_F_REQUEST | NLM_F_DUMP), fd);
260                 if (uid != INVALID_UID) {
261                     return uid;
262                 }
263                 wildcard = new InetSocketAddress(Inet4Address.getByName("0.0.0.0"), 0);
264                 uid = lookupUidByFamily(protocol, local, wildcard, AF_INET,
265                         (short) (NLM_F_REQUEST | NLM_F_DUMP), fd);
266                 if (uid != INVALID_UID) {
267                     return uid;
268                 }
269             } catch (UnknownHostException e) {
270                 Log.e(TAG, e.toString());
271             }
272         }
273         return INVALID_UID;
274     }
275 
276     /**
277      * Use an inet_diag socket to look up the UID associated with the input local and remote
278      * address/port and protocol of a connection.
279      */
getConnectionOwnerUid(int protocol, InetSocketAddress local, InetSocketAddress remote)280     public static int getConnectionOwnerUid(int protocol, InetSocketAddress local,
281                                             InetSocketAddress remote) {
282         int uid = INVALID_UID;
283         FileDescriptor fd = null;
284         try {
285             fd = NetlinkUtils.netlinkSocketForProto(NETLINK_INET_DIAG, SOCKET_RECV_BUFSIZE);
286             connectToKernel(fd);
287             uid = lookupUid(protocol, local, remote, fd);
288         } catch (ErrnoException | SocketException | IllegalArgumentException
289                 | InterruptedIOException e) {
290             Log.e(TAG, e.toString());
291         } finally {
292             closeSocketQuietly(fd);
293         }
294         return uid;
295     }
296 
297     /**
298      * Construct an inet_diag_req_v2 message for querying alive TCP sockets from kernel.
299      */
buildInetDiagReqForAliveTcpSockets(int family)300     public static byte[] buildInetDiagReqForAliveTcpSockets(int family) {
301         return inetDiagReqV2(IPPROTO_TCP,
302                 null /* local addr */,
303                 null /* remote addr */,
304                 family,
305                 (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP) /* flag */,
306                 0 /* pad */,
307                 1 << NetlinkConstants.INET_DIAG_MEMINFO /* idiagExt */,
308                 TCP_ALIVE_STATE_FILTER);
309     }
310 
sendNetlinkDestroyRequest(FileDescriptor fd, int proto, InetDiagMessage diagMsg)311     private static void sendNetlinkDestroyRequest(FileDescriptor fd, int proto,
312             InetDiagMessage diagMsg) throws InterruptedIOException, ErrnoException {
313         final byte[] destroyMsg = InetDiagMessage.inetDiagReqV2(
314                 proto,
315                 diagMsg.inetDiagMsg.id,
316                 diagMsg.inetDiagMsg.idiag_family,
317                 SOCK_DESTROY,
318                 (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_ACK),
319                 0 /* pad */,
320                 0 /* idiagExt */,
321                 1 << diagMsg.inetDiagMsg.idiag_state
322         );
323         NetlinkUtils.sendMessage(fd, destroyMsg, 0, destroyMsg.length, IO_TIMEOUT_MS);
324         NetlinkUtils.receiveNetlinkAck(fd);
325     }
326 
makeNetlinkDumpRequest(int proto, int states, int family)327     private static byte [] makeNetlinkDumpRequest(int proto, int states, int family) {
328         return InetDiagMessage.inetDiagReqV2(
329                 proto,
330                 null /* id */,
331                 family,
332                 SOCK_DIAG_BY_FAMILY,
333                 (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP),
334                 0 /* pad */,
335                 0 /* idiagExt */,
336                 states);
337     }
338 
processNetlinkDumpAndDestroySockets(byte[] dumpReq, FileDescriptor destroyFd, int proto, Predicate<InetDiagMessage> filter)339     private static int processNetlinkDumpAndDestroySockets(byte[] dumpReq,
340             FileDescriptor destroyFd, int proto, Predicate<InetDiagMessage> filter)
341             throws SocketException, InterruptedIOException, ErrnoException {
342         AtomicInteger destroyedSockets = new AtomicInteger(0);
343         Consumer<InetDiagMessage> handleNlDumpMsg = (diagMsg) -> {
344             if (filter.test(diagMsg)) {
345                 try {
346                     sendNetlinkDestroyRequest(destroyFd, proto, diagMsg);
347                     destroyedSockets.getAndIncrement();
348                 } catch (InterruptedIOException | ErrnoException e) {
349                     if (!(e instanceof ErrnoException
350                             && ((ErrnoException) e).errno == ENOENT)) {
351                         Log.e(TAG, "Failed to destroy socket: diagMsg=" + diagMsg + ", " + e);
352                     }
353                 }
354             }
355         };
356 
357         NetlinkUtils.<InetDiagMessage>getAndProcessNetlinkDumpMessages(dumpReq,
358                 NETLINK_INET_DIAG, InetDiagMessage.class, handleNlDumpMsg);
359         return destroyedSockets.get();
360     }
361 
362     /**
363      * Returns whether the InetDiagMessage is for adb socket or not
364      */
365     @VisibleForTesting
isAdbSocket(final InetDiagMessage msg)366     public static boolean isAdbSocket(final InetDiagMessage msg) {
367         // This is inaccurate since adb could run with ROOT_UID or other services can run with
368         // SHELL_UID. But this check covers most cases and enough.
369         // Note that getting service.adb.tcp.port system property is prohibited by sepolicy
370         // TODO: skip the socket only if there is a listen socket owned by SHELL_UID with the same
371         // source port as this socket
372         return msg.inetDiagMsg.idiag_uid == Process.SHELL_UID;
373     }
374 
375     /**
376      * Returns whether the range contains the uid in the InetDiagMessage or not
377      */
378     @VisibleForTesting
containsUid(InetDiagMessage msg, Set<Range<Integer>> ranges)379     public static boolean containsUid(InetDiagMessage msg, Set<Range<Integer>> ranges) {
380         for (final Range<Integer> range: ranges) {
381             if (range.contains(msg.inetDiagMsg.idiag_uid)) {
382                 return true;
383             }
384         }
385         return false;
386     }
387 
isLoopbackAddress(InetAddress addr)388     private static boolean isLoopbackAddress(InetAddress addr) {
389         if (addr.isLoopbackAddress()) return true;
390         if (!(addr instanceof Inet6Address)) return false;
391 
392         // Following check is for v4-mapped v6 address. StructInetDiagSockId contains v4-mapped v6
393         // address as Inet6Address, See StructInetDiagSockId#parse
394         final byte[] addrBytes = addr.getAddress();
395         for (int i = 0; i < 10; i++) {
396             if (addrBytes[i] != 0) return false;
397         }
398         return addrBytes[10] == (byte) 0xff
399                 && addrBytes[11] == (byte) 0xff
400                 && addrBytes[12] == 127;
401     }
402 
403     /**
404      * Returns whether the socket address in the InetDiagMessage is loopback or not
405      */
406     @VisibleForTesting
isLoopback(InetDiagMessage msg)407     public static boolean isLoopback(InetDiagMessage msg) {
408         final InetAddress srcAddr = msg.inetDiagMsg.id.locSocketAddress.getAddress();
409         final InetAddress dstAddr = msg.inetDiagMsg.id.remSocketAddress.getAddress();
410         return isLoopbackAddress(srcAddr)
411                 || isLoopbackAddress(dstAddr)
412                 || srcAddr.equals(dstAddr);
413     }
414 
destroySockets(int proto, int states, Predicate<InetDiagMessage> filter)415     private static void destroySockets(int proto, int states, Predicate<InetDiagMessage> filter)
416             throws ErrnoException, SocketException, InterruptedIOException {
417         FileDescriptor destroyFd = null;
418 
419         try {
420             destroyFd = NetlinkUtils.createNetLinkInetDiagSocket();
421             connectToKernel(destroyFd);
422 
423             for (int family : List.of(AF_INET, AF_INET6)) {
424                 byte[] req = makeNetlinkDumpRequest(proto, states, family);
425 
426                 try {
427                     final int destroyedSockets = processNetlinkDumpAndDestroySockets(
428                             req, destroyFd, proto, filter);
429                     Log.d(TAG, "Destroyed " + destroyedSockets + " sockets"
430                         + ", proto=" + stringForProtocol(proto)
431                         + ", family=" + stringForAddressFamily(family)
432                         + ", states=" + states);
433                 } catch (SocketException | InterruptedIOException | ErrnoException e) {
434                     Log.e(TAG, "Failed to send netlink dump request or receive messages: " + e);
435                     continue;
436                 }
437             }
438         } finally {
439             closeSocketQuietly(destroyFd);
440         }
441     }
442 
443     /**
444      * Close tcp sockets that match the following condition
445      *  1. TCP status is one of TCP_ESTABLISHED, TCP_SYN_SENT, and TCP_SYN_RECV
446      *  2. Owner uid of socket is not in the exemptUids
447      *  3. Owner uid of socket is in the ranges
448      *  4. Socket is not loopback
449      *  5. Socket is not adb socket
450      *
451      * @param ranges target uid ranges
452      * @param exemptUids uids to skip close socket
453      */
destroyLiveTcpSockets(Set<Range<Integer>> ranges, Set<Integer> exemptUids)454     public static void destroyLiveTcpSockets(Set<Range<Integer>> ranges, Set<Integer> exemptUids)
455             throws SocketException, InterruptedIOException, ErrnoException {
456         final long startTimeMs = SystemClock.elapsedRealtime();
457         destroySockets(IPPROTO_TCP, TCP_ALIVE_STATE_FILTER,
458                 (diagMsg) -> !exemptUids.contains(diagMsg.inetDiagMsg.idiag_uid)
459                         && containsUid(diagMsg, ranges)
460                         && !isLoopback(diagMsg)
461                         && !isAdbSocket(diagMsg));
462         final long durationMs = SystemClock.elapsedRealtime() - startTimeMs;
463         Log.d(TAG, "Destroyed live tcp sockets for uids=" + ranges + " exemptUids=" + exemptUids
464                 + " in " + durationMs + "ms");
465     }
466 
467     /**
468      * Close tcp sockets that match the following condition
469      *  1. TCP status is one of TCP_ESTABLISHED, TCP_SYN_SENT, and TCP_SYN_RECV
470      *  2. Owner uid of socket is in the targetUids
471      *  3. Socket is not loopback
472      *  4. Socket is not adb socket
473      *
474      * @param ownerUids target uids to close sockets
475      */
destroyLiveTcpSocketsByOwnerUids(Set<Integer> ownerUids)476     public static void destroyLiveTcpSocketsByOwnerUids(Set<Integer> ownerUids)
477             throws SocketException, InterruptedIOException, ErrnoException {
478         final long startTimeMs = SystemClock.elapsedRealtime();
479         destroySockets(IPPROTO_TCP, TCP_ALIVE_STATE_FILTER,
480                 (diagMsg) -> ownerUids.contains(diagMsg.inetDiagMsg.idiag_uid)
481                         && !isLoopback(diagMsg)
482                         && !isAdbSocket(diagMsg));
483         final long durationMs = SystemClock.elapsedRealtime() - startTimeMs;
484         Log.d(TAG, "Destroyed live tcp sockets for uids=" + ownerUids + " in " + durationMs + "ms");
485     }
486 
487     @Override
toString()488     public String toString() {
489         return "InetDiagMessage{ "
490                 + "nlmsghdr{"
491                 + (mHeader == null ? "" : mHeader.toString(NETLINK_INET_DIAG)) + "}, "
492                 + "inet_diag_msg{"
493                 + (inetDiagMsg == null ? "" : inetDiagMsg.toString()) + "} "
494                 + "}";
495     }
496 }
497