1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.networkstack.netlink;
18 
19 import static android.net.netlink.NetlinkConstants.SOCKDIAG_MSG_HEADER_SIZE;
20 import static android.net.util.DataStallUtils.CONFIG_TCP_PACKETS_FAIL_PERCENTAGE;
21 import static android.net.util.DataStallUtils.DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE;
22 import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
23 import static android.system.OsConstants.AF_INET;
24 
25 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
26 
27 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
28 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
29 
30 import static junit.framework.Assert.assertEquals;
31 import static junit.framework.Assert.assertFalse;
32 import static junit.framework.Assert.assertTrue;
33 
34 import static org.junit.Assume.assumeTrue;
35 import static org.mockito.ArgumentMatchers.anyInt;
36 import static org.mockito.ArgumentMatchers.eq;
37 import static org.mockito.Mockito.any;
38 import static org.mockito.Mockito.atLeastOnce;
39 import static org.mockito.Mockito.verify;
40 import static org.mockito.Mockito.verifyNoMoreInteractions;
41 import static org.mockito.Mockito.when;
42 
43 import android.net.INetd;
44 import android.net.MarkMaskParcel;
45 import android.net.Network;
46 import android.net.netlink.StructNlMsgHdr;
47 import android.os.Build;
48 import android.util.Log;
49 import android.util.Log.TerribleFailureHandler;
50 
51 import androidx.test.filters.SmallTest;
52 import androidx.test.runner.AndroidJUnit4;
53 
54 import com.android.networkstack.apishim.ConstantsShim;
55 import com.android.networkstack.apishim.NetworkShimImpl;
56 import com.android.testutils.DevSdkIgnoreRule;
57 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
58 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
59 
60 import libcore.util.HexEncoding;
61 
62 import org.junit.After;
63 import org.junit.Before;
64 import org.junit.Rule;
65 import org.junit.Test;
66 import org.junit.runner.RunWith;
67 import org.mockito.Mock;
68 import org.mockito.MockitoAnnotations;
69 import org.mockito.MockitoSession;
70 import org.mockito.quality.Strictness;
71 
72 import java.io.FileDescriptor;
73 import java.nio.ByteBuffer;
74 import java.nio.ByteOrder;
75 
76 // TODO: Add more tests for missing coverage.
77 @RunWith(AndroidJUnit4.class)
78 @SmallTest
79 public class TcpSocketTrackerTest {
80     private static final int TEST_BUFFER_SIZE = 1024;
81     private static final String DIAG_MSG_HEX =
82             // struct nlmsghdr.
83             "58000000" +      // length = 88
84             "1400" +         // type = SOCK_DIAG_BY_FAMILY
85             "0301" +         // flags = NLM_F_REQUEST | NLM_F_DUMP
86             "00000000" +     // seqno
87             "00000000" +     // pid (0 == kernel)
88             // struct inet_diag_req_v2
89             "02" +           // family = AF_INET
90             "06" +           // state
91             "00" +           // timer
92             "00" +           // retrans
93             // inet_diag_sockid
94             "DEA5" +         // idiag_sport = 42462
95             "71B9" +         // idiag_dport = 47473
96             "0a006402000000000000000000000000" + // idiag_src = 10.0.100.2
97             "08080808000000000000000000000000" + // idiag_dst = 8.8.8.8
98             "00000000" +    // idiag_if
99             "34ED000076270000" + // idiag_cookie = 43387759684916
100             "00000000" +    // idiag_expires
101             "00000000" +    // idiag_rqueue
102             "00000000" +    // idiag_wqueue
103             "00000000" +    // idiag_uid
104             "00000000";    // idiag_inode
105     private static final byte[] SOCK_DIAG_MSG_BYTES =
106             HexEncoding.decode(DIAG_MSG_HEX.toCharArray(), false);
107     // Hexadecimal representation of a SOCK_DIAG response with tcp info.
108     private static final String SOCK_DIAG_TCP_INET_HEX =
109             // struct nlmsghdr.
110             "14010000" +        // length = 276
111             "1400" +            // type = SOCK_DIAG_BY_FAMILY
112             "0301" +            // flags = NLM_F_REQUEST | NLM_F_DUMP
113             "00000000" +        // seqno
114             "00000000" +        // pid (0 == kernel)
115             // struct inet_diag_req_v2
116             "02" +              // family = AF_INET
117             "06" +              // state
118             "00" +              // timer
119             "00" +              // retrans
120             // inet_diag_sockid
121             "DEA5" +            // idiag_sport = 42462
122             "71B9" +            // idiag_dport = 47473
123             "0a006402000000000000000000000000" + // idiag_src = 10.0.100.2
124             "08080808000000000000000000000000" + // idiag_dst = 8.8.8.8
125             "00000000" +            // idiag_if
126             "34ED000076270000" +    // idiag_cookie = 43387759684916
127             "00000000" +            // idiag_expires
128             "00000000" +            // idiag_rqueue
129             "00000000" +            // idiag_wqueue
130             "00000000" +            // idiag_uid
131             "00000000" +            // idiag_inode
132             // rtattr
133             "0500" +            // len = 5
134             "0800" +            // type = 8
135             "00000000" +        // data
136             "0800" +            // len = 8
137             "0F00" +            // type = 15(INET_DIAG_MARK)
138             "850A0C00" +        // data, socket mark=789125
139             "AC00" +            // len = 172
140             "0200" +            // type = 2(INET_DIAG_INFO)
141             // tcp_info
142             "01" +              // state = TCP_ESTABLISHED
143             "00" +              // ca_state = TCP_CA_OPEN
144             "05" +              // retransmits = 5
145             "00" +              // probes = 0
146             "00" +              // backoff = 0
147             "07" +              // option = TCPI_OPT_WSCALE|TCPI_OPT_SACK|TCPI_OPT_TIMESTAMPS
148             "88" +              // wscale = 8
149             "00" +              // delivery_rate_app_limited = 0
150             "4A911B00" +        // rto = 1806666
151             "00000000" +        // ato = 0
152             "2E050000" +        // sndMss = 1326
153             "18020000" +        // rcvMss = 536
154             "00000000" +        // unsacked = 0
155             "00000000" +        // acked = 0
156             "00000000" +        // lost = 0
157             "00000000" +        // retrans = 0
158             "00000000" +        // fackets = 0
159             "BB000000" +        // lastDataSent = 187
160             "00000000" +        // lastAckSent = 0
161             "BB000000" +        // lastDataRecv = 187
162             "BB000000" +        // lastDataAckRecv = 187
163             "DC050000" +        // pmtu = 1500
164             "30560100" +        // rcvSsthresh = 87600
165             "3E2C0900" +        // rttt = 601150
166             "1F960400" +        // rttvar = 300575
167             "78050000" +        // sndSsthresh = 1400
168             "0A000000" +        // sndCwnd = 10
169             "A8050000" +        // advmss = 1448
170             "03000000" +        // reordering = 3
171             "00000000" +        // rcvrtt = 0
172             "30560100" +        // rcvspace = 87600
173             "00000000" +        // totalRetrans = 0
174             "53AC000000000000" +    // pacingRate = 44115
175             "FFFFFFFFFFFFFFFF" +    // maxPacingRate = 18446744073709551615
176             "0100000000000000" +    // bytesAcked = 1
177             "0000000000000000" +    // bytesReceived = 0
178             "0A000000" +        // SegsOut = 10
179             "00000000" +        // SegsIn = 0
180             "00000000" +        // NotSentBytes = 0
181             "3E2C0900" +        // minRtt = 601150
182             "00000000" +        // DataSegsIn = 0
183             "00000000" +        // DataSegsOut = 0
184             "0000000000000000"; // deliverRate = 0
185     private static final byte[] SOCK_DIAG_TCP_INET_BYTES =
186             HexEncoding.decode(SOCK_DIAG_TCP_INET_HEX.toCharArray(), false);
187     private static final TcpInfo TEST_TCPINFO =
188             new TcpInfo(5 /* retransmits */, 0 /* lost */, 10 /* segsOut */, 0 /* segsIn */);
189 
190     private static final String TEST_RESPONSE_HEX = SOCK_DIAG_TCP_INET_HEX
191             // struct nlmsghdr
192             + "14000000"     // length = 20
193             + "0300"         // type = NLMSG_DONE
194             + "0301"         // flags = NLM_F_REQUEST | NLM_F_DUMP
195             + "00000000"     // seqno
196             + "00000000"     // pid (0 == kernel)
197             // struct inet_diag_req_v2
198             + "02"           // family = AF_INET
199             + "06"           // state
200             + "00"           // timer
201             + "00";          // retrans
202     private static final byte[] TEST_RESPONSE_BYTES =
203             HexEncoding.decode(TEST_RESPONSE_HEX.toCharArray(), false);
204     private static final int TEST_NETID1 = 0xA85;
205     private static final int TEST_NETID2 = 0x1A85;
206     private static final int TEST_NETID1_FWMARK = 0x0A85;
207     private static final int TEST_NETID2_FWMARK = 0x1A85;
208     private static final int NETID_MASK = 0xffff;
209     @Mock private TcpSocketTracker.Dependencies mDependencies;
210     @Mock private FileDescriptor mMockFd;
211     @Mock private INetd mNetd;
212     private final Network mNetwork = new Network(TEST_NETID1);
213     private final Network mOtherNetwork = new Network(TEST_NETID2);
214     private MockitoSession mSession;
215     private TerribleFailureHandler mOldWtfHandler;
216 
217     @Rule
218     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
219 
220     @Before
setUp()221     public void setUp() throws Exception {
222         MockitoAnnotations.initMocks(this);
223         // Override the default TerribleFailureHandler, as that handler might terminate the process
224         // (if we're on an eng build).
225         mOldWtfHandler =
226                 Log.setWtfHandler((tag, what, system) -> Log.e(tag, what.getMessage(), what));
227         when(mDependencies.getNetd()).thenReturn(mNetd);
228         when(mDependencies.isTcpInfoParsingSupported()).thenReturn(true);
229         when(mDependencies.connectToKernel()).thenReturn(mMockFd);
230         when(mDependencies.getDeviceConfigPropertyInt(
231                 eq(NAMESPACE_CONNECTIVITY),
232                 eq(CONFIG_TCP_PACKETS_FAIL_PERCENTAGE),
233                 anyInt())).thenReturn(DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE);
234         mSession = mockitoSession()
235                 .spyStatic(NetworkShimImpl.class)
236                 .strictness(Strictness.WARN)
237                 .startMocking();
238 
239         when(mNetd.getFwmarkForNetwork(eq(TEST_NETID1)))
240                 .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID1_FWMARK));
241     }
242 
243     @After
tearDown()244     public void tearDown() {
245         mSession.finishMocking();
246         Log.setWtfHandler(mOldWtfHandler);
247     }
248 
makeMarkMaskParcel(final int mask, final int mark)249     private MarkMaskParcel makeMarkMaskParcel(final int mask, final int mark) {
250         final MarkMaskParcel parcel = new MarkMaskParcel();
251         parcel.mask = mask;
252         parcel.mark = mark;
253         return parcel;
254     }
255 
getByteBuffer(final byte[] bytes)256     private ByteBuffer getByteBuffer(final byte[] bytes) {
257         final ByteBuffer buffer = ByteBuffer.wrap(bytes);
258         buffer.order(ByteOrder.LITTLE_ENDIAN);
259         return buffer;
260     }
261 
262     @Test
testParseSockInfo()263     public void testParseSockInfo() {
264         final ByteBuffer buffer = getByteBuffer(SOCK_DIAG_TCP_INET_BYTES);
265         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
266         buffer.position(SOCKDIAG_MSG_HEADER_SIZE);
267         final TcpSocketTracker.SocketInfo parsed =
268                 tst.parseSockInfo(buffer, AF_INET, 276, 100L);
269 
270         assertEquals(parsed.tcpInfo, TEST_TCPINFO);
271         assertEquals(parsed.fwmark, 789125);
272         assertEquals(parsed.updateTime, 100);
273         assertEquals(parsed.ipFamily, AF_INET);
274     }
275 
276     @Test
testEnoughBytesRemainForValidNlMsg()277     public void testEnoughBytesRemainForValidNlMsg() {
278         final ByteBuffer buffer = ByteBuffer.allocate(TEST_BUFFER_SIZE);
279 
280         buffer.position(TEST_BUFFER_SIZE - StructNlMsgHdr.STRUCT_SIZE);
281         assertTrue(TcpSocketTracker.enoughBytesRemainForValidNlMsg(buffer));
282         // Remaining buffer size is less than a valid StructNlMsgHdr size.
283         buffer.position(TEST_BUFFER_SIZE - StructNlMsgHdr.STRUCT_SIZE + 1);
284         assertFalse(TcpSocketTracker.enoughBytesRemainForValidNlMsg(buffer));
285 
286         buffer.position(TEST_BUFFER_SIZE);
287         assertFalse(TcpSocketTracker.enoughBytesRemainForValidNlMsg(buffer));
288     }
289 
290     @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q
testPollSocketsInfo()291     public void testPollSocketsInfo() throws Exception {
292         // This test requires shims that provide API 30 access
293         assumeTrue(ConstantsShim.VERSION >= 30);
294         when(mDependencies.isTcpInfoParsingSupported()).thenReturn(false);
295         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
296         assertFalse(tst.pollSocketsInfo());
297 
298         when(mDependencies.isTcpInfoParsingSupported()).thenReturn(true);
299         // No enough bytes remain for a valid NlMsg.
300         final ByteBuffer invalidBuffer = ByteBuffer.allocate(1);
301         invalidBuffer.order(ByteOrder.LITTLE_ENDIAN);
302         when(mDependencies.recvMessage(any())).thenReturn(invalidBuffer);
303         assertTrue(tst.pollSocketsInfo());
304         assertEquals(-1, tst.getLatestPacketFailPercentage());
305         assertEquals(0, tst.getSentSinceLastRecv());
306 
307         // Header only.
308         final ByteBuffer headerBuffer = getByteBuffer(SOCK_DIAG_MSG_BYTES);
309         when(mDependencies.recvMessage(any())).thenReturn(headerBuffer);
310         assertTrue(tst.pollSocketsInfo());
311         assertEquals(-1, tst.getLatestPacketFailPercentage());
312         assertEquals(0, tst.getSentSinceLastRecv());
313 
314         final ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES);
315         when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
316         assertTrue(tst.pollSocketsInfo());
317 
318         assertEquals(10, tst.getSentSinceLastRecv());
319         assertEquals(50, tst.getLatestPacketFailPercentage());
320         assertFalse(tst.isDataStallSuspected());
321         // Lower the threshold.
322         when(mDependencies.getDeviceConfigPropertyInt(any(), eq(CONFIG_TCP_PACKETS_FAIL_PERCENTAGE),
323                 anyInt())).thenReturn(40);
324         // No device config change. Using cache value.
325         assertFalse(tst.isDataStallSuspected());
326         // Trigger a config update
327         tst.mConfigListener.onPropertiesChanged(null /* properties */);
328         assertTrue(tst.isDataStallSuspected());
329     }
330 
331     @Test
testTcpInfoParsingUnsupported()332     public void testTcpInfoParsingUnsupported() {
333         doReturn(false).when(mDependencies).isTcpInfoParsingSupported();
334         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
335         verify(mDependencies).getNetd();
336 
337         assertFalse(tst.pollSocketsInfo());
338         assertEquals(-1, tst.getLatestPacketFailPercentage());
339         assertEquals(-1, tst.getLatestReceivedCount());
340         assertEquals(-1, tst.getSentSinceLastRecv());
341         assertFalse(tst.isDataStallSuspected());
342 
343         verify(mDependencies, atLeastOnce()).isTcpInfoParsingSupported();
344         verifyNoMoreInteractions(mDependencies);
345     }
346 
347     @Test @IgnoreAfter(Build.VERSION_CODES.Q)
testTcpInfoParsingNotSupportedOnQ()348     public void testTcpInfoParsingNotSupportedOnQ() {
349         assertFalse(new TcpSocketTracker.Dependencies(getInstrumentation().getContext())
350                 .isTcpInfoParsingSupported());
351     }
352 
353     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
testTcpInfoParsingSupportedFromR()354     public void testTcpInfoParsingSupportedFromR() {
355         assertTrue(new TcpSocketTracker.Dependencies(getInstrumentation().getContext())
356                 .isTcpInfoParsingSupported());
357     }
358 
359     private static final String BAD_DIAG_MSG_HEX =
360         // struct nlmsghdr.
361             "00000058" +      // length = 1476395008
362             "1400" +         // type = SOCK_DIAG_BY_FAMILY
363             "0301" +         // flags = NLM_F_REQUEST | NLM_F_DUMP
364             "00000000" +     // seqno
365             "00000000" +     // pid (0 == kernel)
366             // struct inet_diag_req_v2
367             "02" +           // family = AF_INET
368             "06" +           // state
369             "00" +           // timer
370             "00" +           // retrans
371             // inet_diag_sockid
372             "DEA5" +         // idiag_sport = 42462
373             "71B9" +         // idiag_dport = 47473
374             "0a006402000000000000000000000000" + // idiag_src = 10.0.100.2
375             "08080808000000000000000000000000" + // idiag_dst = 8.8.8.8
376             "00000000" +    // idiag_if
377             "34ED000076270000" + // idiag_cookie = 43387759684916
378             "00000000" +    // idiag_expires
379             "00000000" +    // idiag_rqueue
380             "00000000" +    // idiag_wqueue
381             "00000000" +    // idiag_uid
382             "00000000";    // idiag_inode
383     private static final byte[] BAD_SOCK_DIAG_MSG_BYTES =
384         HexEncoding.decode(BAD_DIAG_MSG_HEX.toCharArray(), false);
385 
386     @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q
testPollSocketsInfo_BadFormat()387     public void testPollSocketsInfo_BadFormat() throws Exception {
388         // This test requires shims that provide API 30 access
389         assumeTrue(ConstantsShim.VERSION >= 30);
390         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
391         ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES);
392 
393         when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
394         assertTrue(tst.pollSocketsInfo());
395         assertEquals(10, tst.getSentSinceLastRecv());
396         assertEquals(50, tst.getLatestPacketFailPercentage());
397 
398         tcpBuffer = getByteBuffer(BAD_SOCK_DIAG_MSG_BYTES);
399         when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
400         assertTrue(tst.pollSocketsInfo());
401         // Expect no additional packets, so still 10.
402         assertEquals(10, tst.getSentSinceLastRecv());
403         // Expect to reset to 0.
404         assertEquals(0, tst.getLatestPacketFailPercentage());
405     }
406 
407     @Test
testUnMatchNetwork()408     public void testUnMatchNetwork() throws Exception {
409         when(mNetd.getFwmarkForNetwork(eq(TEST_NETID2)))
410                 .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID2_FWMARK));
411         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mOtherNetwork);
412         final ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES);
413         when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
414         assertTrue(tst.pollSocketsInfo());
415 
416         assertEquals(0, tst.getSentSinceLastRecv());
417         assertEquals(-1, tst.getLatestPacketFailPercentage());
418         assertFalse(tst.isDataStallSuspected());
419     }
420 }
421