1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import os
20import random
21import select
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import threading
25import time
26import unittest
27
28import cstruct
29import multinetwork_base
30import net_test
31import packets
32import sock_diag
33import tcp_test
34
# Mostly empty structure definition containing only the fields we currently use.
# NOTE(review): assuming cstruct uses struct-module format codes, "64xI" skips
# the first 64 bytes of the tcp_info buffer and decodes the next unsigned
# 32-bit value as tcpi_rcv_ssthresh.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

# Number of socket pairs created by _CreateLotsOfSockets.
NUM_SOCKETS = 30
# Empty bytecode, used to request unfiltered dumps.
NO_BYTECODE = ""
LINUX_4_9_OR_ABOVE = net_test.LINUX_VERSION >= (4, 9, 0)
LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0)

# SCTP protocol number, defined locally (presumably because the socket module
# does not always export IPPROTO_SCTP).
IPPROTO_SCTP = 132
44
def HaveUdpDiag():
  """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.

  This config is required for devices running a 4.9 kernel that ship with P.
  In this case, always assume the config is there and use the tests to check
  that the config is enabled as required.

  For all other kernel versions, there is no way to tell whether a dump
  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
  returns an empty result instead of an error. So, create a connected UDP
  socket and check whether a UDP dump returns at least one socket. If it does
  not, some tests will be skipped.

  Returns:
    True if the kernel is 4.9 or above, or CONFIG_INET_UDP_DIAG is enabled.
    False otherwise.
  """
  if LINUX_4_9_OR_ABOVE:
    return True

  s = socket(AF_INET6, SOCK_DGRAM, 0)
  try:
    # Connect the socket to its own address so that at least one UDP socket
    # is known to exist while we dump.
    s.bind(("::", 0))
    s.connect(s.getsockname())
    sd = sock_diag.SockDiag()
    return len(sd.DumpAllInetSockets(IPPROTO_UDP, NO_BYTECODE)) > 0
  finally:
    # Always release the probe socket, even if bind/connect/dump raise.
    s.close()
71
def HaveSctp():
  """Returns whether SCTP sockets can be created on this kernel.

  Kernels older than 4.7 are reported as unsupported without probing. On newer
  kernels, an IPv4 SCTP stream socket is created and immediately closed; any
  IOError during that probe means SCTP is unavailable.
  """
  if net_test.LINUX_VERSION < (4, 7, 0):
    return False
  try:
    socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP).close()
  except IOError:
    return False
  return True
81
# Kernel capability probes, evaluated once at import time. HAVE_UDP_DIAG gates
# the UDP tests via @unittest.skipUnless; HAVE_SCTP is presumably used by SCTP
# tests elsewhere in this file.
HAVE_UDP_DIAG = HaveUdpDiag()
HAVE_SCTP = HaveSctp()
84
85
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS socket pairs on randomly chosen loopback addresses.

    Each pair uses IPv4 (127.0.0.1), IPv6 (::1) or IPv4-mapped IPv6
    (::ffff:127.0.0.1), picked at random.

    Args:
      socktype: The socket type, e.g., SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock still has a peer (getpeername() does not raise)."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Checks that a (diag_msg, attrs) tuple describes socket s.

    Compares family, source address/port, destination address/port (or checks
    that s is unconnected if the diag destination is a wildcard) and mark.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs instructions into bytecode and checks that it decodes cleanly.

    Args:
      instructions: A list of bytecode instruction tuples.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock with SOCK_DESTROY while call(sock) is blocked."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed in tearDown.
    self.socketpairs = {}

  def tearDown(self):
    try:
      for socketpair in self.socketpairs.values():
        for s in socketpair:
          s.close()
    finally:
      # Run the base class cleanup even if closing a socket fails.
      super(SockDiagBaseTest, self).tearDown()
199
200
class SockDiagTest(SockDiagBaseTest):
  """Tests for SOCK_DIAG dumps, lookups, bytecode filters and cookies."""

  def _CloseOnCleanup(self, socks):
    """Registers every socket in socks to be closed when the test ends."""
    for sock in socks:
      self.addCleanup(sock.close)

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    self._CloseOnCleanup(socketpair)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works."""
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. Pairs are keyed by the first
    # socket's (sport, dport); the second socket of a pair sees the ports
    # reversed, so accept both orders. Each socket gets its own key.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                       NO_BYTECODE,
                                                       states=states))

    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    self._CloseOnCleanup(pair4)
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
    self._CloseOnCleanup(pair6)

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    self._CloseOnCleanup(pair5)
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that netlink message types other than SOCK_DIAG_BY_FAMILY fail."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    self._CloseOnCleanup(socketpair)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)

  @unittest.skipUnless(LINUX_4_9_OR_ABOVE, "SO_COOKIE not supported")
  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug.  Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, self.RandomNetid(), "mark")
        s.connect((self.GetRemoteSocketAddress(version), 53))

        # Create a fully-specified diag req from our socket, including cookie
        # if we can get it.
        req = self.sock_diag.DiagReqFromSocket(s)
        if LINUX_4_9_OR_ABOVE:
          req.id.cookie = s.getsockopt(net_test.SOL_SOCKET,
                                       net_test.SO_COOKIE, 8)
        else:
          # INET_DIAG_NOCOOKIE in both 32-bit halves of the 8-byte cookie.
          req.id.cookie = "\xff" * 8

        # As is, this request does not find anything.
        with self.assertRaisesErrno(ENOENT):
          self.sock_diag.GetSockInfo(req)

        # But if we swap src and dst, the kernel finds our socket.
        req.id.sport, req.id.dport = req.id.dport, req.id.sport
        req.id.src, req.id.dst = req.id.dst, req.id.src

        self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
      finally:
        s.close()
431
432
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY closes sockets correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for pair in self.socketpairs.values():
      # Destroy one randomly chosen end of the pair. The RST this sends closes
      # the other end as well.
      victim = random.choice(pair)
      close_by_fd = random.randrange(0, 2) == 1

      if not close_by_fd:
        diag_msg = self.sock_diag.FindSockDiagFromFd(victim)

        # A wrong cookie must be rejected with ENOENT, leaving the socket
        # connected.
        good_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(good_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(victim)

        # The right cookie destroys it.
        req.id.cookie = good_cookie
        self.sock_diag.CloseSocket(req)
      else:
        self.sock_diag.CloseSocketFromFd(victim)

      # Both ends of the pair must now be closed.
      self.assertSocketsClosed(pair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
510
511
class SocketExceptionThread(threading.Thread):
  """Daemon thread that calls operation(sock) and records how it failed.

  An IOError or AssertionError raised by the operation is stored in the
  `exception` attribute; if the operation returns normally, `exception`
  remains None.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock, self.operation = sock, operation

  def run(self):
    target, arg = self.operation, self.sock
    try:
      target(arg)
    except (IOError, AssertionError) as err:
      self.exception = err
526
527
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that use TCP connections on the test networks."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(list(self.tuns.keys()))
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)

    # Dump AF_INET6 TCP sockets in SYN_RECV on our listening port.
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    syn_recv_mask = 1 << tcp_test.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, syn_recv_mask, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    self.assertTrue(children)

    padded_remote = self.sock_diag.PaddedAddress(self.remotesockaddr)
    padded_local = self.sock_diag.PaddedAddress(self.mysockaddr)
    for child, _ in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
      self.assertEqual(padded_remote, child.id.dst)
      self.assertEqual(padded_local, child.id.src)
552
553
class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks the initial TCP receive window on every test network."""

  # Strict lower bound for the initial receive window that the tests expect.
  RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000
  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"

  def setUp(self):
    """Sets tcp_default_init_rwnd to 60, or checks it is absent on 4.19+."""
    super(TcpRcvWindowTest, self).setUp()
    if LINUX_4_19_OR_ABOVE:
      self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")
      return

    # Use a context manager so the value is flushed and the file descriptor
    # closed deterministically before the test runs.
    with open(self.TCP_DEFAULT_INIT_RWND, "w") as f:
      f.write("60")

  def checkInitRwndSize(self, version, netid):
    """Checks tcpi_rcv_ssthresh of an accepted connection exceeds RWND_SIZE."""
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    tcp_info = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
                                                net_test.TCP_INFO,
                                                len(TcpInfo)))
    self.assertLess(self.RWND_SIZE, tcp_info.tcpi_rcv_ssthresh,
                    "Tcp rwnd of netid=%d, version=%d is not enough. "
                    "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
                                                tcp_info.tcpi_rcv_ssthresh))

  def checkSynPacketWindowSize(self, version, netid):
    """Checks that our outgoing SYN advertises a window above RWND_SIZE."""
    s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)
      self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      syn = self.ExpectPacketOn(netid, msg, expected)
      self.assertLess(self.RWND_SIZE, syn.window)
    finally:
      # Close the socket even if an expectation fails.
      s.close()

  def testTcpCwndSize(self):
    """Checks the initial receive window for every IP version and network."""
    for version in [4, 5, 6]:
      for netid in self.NETIDS:
        self.checkInitRwndSize(version, netid)
        self.checkSynPacketWindowSize(version, netid)
596
597
598class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
599
600  def setUp(self):
601    super(SockDestroyTcpTest, self).setUp()
602    self.netid = random.choice(list(self.tuns.keys()))
603
604  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
605    """Closes the socket and checks whether a RST is sent or not."""
606    if sock is not None:
607      self.assertIsNone(req, "Must specify sock or req, not both")
608      self.sock_diag.CloseSocketFromFd(sock)
609      self.assertRaisesErrno(EINVAL, sock.accept)
610    else:
611      self.assertIsNone(sock, "Must specify sock or req, not both")
612      self.sock_diag.CloseSocket(req)
613
614    if expect_reset:
615      desc, rst = self.RstPacket()
616      msg = "%s: expecting %s: " % (msg, desc)
617      self.ExpectPacketOn(self.netid, msg, rst)
618    else:
619      msg = "%s: " % msg
620      self.ExpectNoPacketsOn(self.netid, msg)
621
622    if sock is not None and do_close:
623      sock.close()
624
625  def CheckTcpReset(self, state, statename):
626    for version in [4, 5, 6]:
627      msg = "Closing incoming IPv%d %s socket" % (version, statename)
628      self.IncomingConnection(version, state, self.netid)
629      self.CheckRstOnClose(self.s, None, False, msg)
630      if state != tcp_test.TCP_LISTEN:
631        msg = "Closing accepted IPv%d %s socket" % (version, statename)
632        self.CheckRstOnClose(self.accepted, None, True, msg)
633
634  def testTcpResets(self):
635    """Checks that closing sockets in appropriate states sends a RST."""
636    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
637    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
638    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
639
  def testFinWait1Socket(self):
    """Checks that SOCK_DESTROY leaves a userspace-closed FIN_WAIT1 socket alone.

    For each IP version, closes an accepted socket so that it enters FIN_WAIT1
    and sends a FIN, destroys it via its diag request, and checks that no RST
    is sent and the socket is still in FIN_WAIT1. It then ACKs the FIN and, on
    3.18 and above, checks that a FIN_WAIT2 socket shows up in a dump.
    """
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      net_test.EnableFinWait(self.accepted)
      self.accepted.close()
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      desc, fin = self.FinPacket()
      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)

      # Destroy the socket and expect no RST.
      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)

      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
      # because userspace had already closed it.
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      # Version 5 (presumably IPv4-mapped IPv6 sockets) uses an IPv4 packet.
      finversion = 4 if version == 5 else version
      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      # Result unused: this lookup only confirms the socket can still be found.
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket. This does not appear
      # to work on 3.10.
      if net_test.LINUX_VERSION >= (3, 18):
        diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
        infos = self.sock_diag.Dump(diag_req, "")
        self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
                            for diag_msg, attrs in infos),
                        "Expected to find FIN_WAIT2 socket in %s" % infos)
679
680  def FindChildSockets(self, s):
681    """Finds the SYN_RECV child sockets of a given listening socket."""
682    d = self.sock_diag.FindSockDiagFromFd(self.s)
683    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
684    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
685    req.id.cookie = "\x00" * 8
686
687    bad_bytecode = self.PackAndCheckBytecode(
688        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
689    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
690
691    bytecode = self.PackAndCheckBytecode(
692        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
693    children = self.sock_diag.Dump(req, bytecode)
694    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
695            for d, _ in children]
696
  def CheckChildSocket(self, version, statename, parent_first):
    """Checks SOCK_DESTROY on a listening socket and its child socket.

    Creates an incoming connection in the given state, finds the single child
    of the listening socket, then destroys parent and child in the order
    selected by parent_first, checking which closes send a RST.

    Args:
      version: The IP version passed to IncomingConnection (4, 5 or 6).
      statename: Name of a state constant in tcp_test, e.g., "TCP_SYN_RECV"
        or "TCP_NOT_YET_ACCEPTED".
      parent_first: If True, destroy the parent before the children;
        otherwise destroy the children first.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    # A TCP_NOT_YET_ACCEPTED child is expected to report TCP_ESTABLISHED.
    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    # Before 4.4, we can see those sockets in dumps, but we can't fetch
    # or close them.
    can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

    for child in children:
      if can_close_children:
        diag_msg, attrs = self.sock_diag.GetSockInfo(child)
        self.assertEqual(diag_msg.state, expected_state)
        self.assertMarkIs(self.netid, attrs)
      else:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    # Destroys the parent (self.s) and checks it can no longer be found.
    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    # Checks that no child can be found any more.
    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    # Destroys each child, expecting a RST only for established children.
    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        # Confirms the child still exists before destroying it.
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      elif can_close_children:
        CloseChildren()
        # NOTE(review): CloseChildren already calls this; the check repeats.
        CheckChildrenClosed()
      self.s.close()
    else:
      if can_close_children:
        CloseChildren()
      CloseParent(False)
      self.s.close()
758
759  def testChildSockets(self):
760    for version in [4, 5, 6]:
761      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
762      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
763      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
764      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
765
766  def testAcceptInterrupted(self):
767    """Tests that accept() is interrupted by SOCK_DESTROY."""
768    for version in [4, 5, 6]:
769      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
770      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
771      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
772      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
773      self.assertRaisesErrno(EINVAL, self.s.accept)
774      # TODO: this should really return an error such as ENOTCONN...
775      self.assertEqual("", self.s.recv(4096))
776
777  def testReadInterrupted(self):
778    """Tests that read() is interrupted by SOCK_DESTROY."""
779    for version in [4, 5, 6]:
780      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
781      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
782                                   ECONNABORTED)
783      # Writing returns EPIPE, and reading returns EOF.
784      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
785      self.assertEqual("", self.accepted.recv(4096))
786      self.assertEqual("", self.accepted.recv(4096))
787
788  def testConnectInterrupted(self):
789    """Tests that connect() is interrupted by SOCK_DESTROY."""
790    for version in [4, 5, 6]:
791      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
792      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
793      self.SelectInterface(s, self.netid, "mark")
794
795      remotesockaddr = self.GetRemoteSocketAddress(version)
796      remoteaddr = self.GetRemoteAddress(version)
797      s.bind(("", 0))
798      _, sport = s.getsockname()[:2]
799      self.CloseDuringBlockingCall(
800          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
801      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
802                              remoteaddr, sport=sport, seq=None)
803      self.ExpectPacketOn(self.netid, desc, syn)
804      msg = "SOCK_DESTROY of socket in connect, expected no RST"
805      self.ExpectNoPacketsOn(self.netid, msg)
806
807
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.

  The behaviour of poll() in these cases is not what we might expect: if only
  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
  is (also) specified, it will only return POLLOUT.
  """

  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # (flag, name) pairs used by PollResultToString to render event masks.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    super(PollOnCloseTest, self).setUp()
    # Each test runs its connections on one randomly chosen tun network.
    self.netid = random.choice(list(self.tuns.keys()))

  def PollResultToString(self, poll_events, ignoremask):
    """Converts (fd, eventmask) pairs into (fd, "IN|OUT|...") pairs.

    Bits set in ignoremask are omitted from the string, so results that only
    differ in those bits compare equal.
    """
    out = []
    for fd, event in poll_events:
      # Parenthesized for clarity; in Python, & already binds tighter than !=.
      flags = [name for (flag, name) in self.POLL_FLAGS
               if (event & flag & ~ignoremask) != 0]
      out.append((fd, "|".join(flags)))
    return out

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for the events in mask and asserts the result is expected.

    Bits in ignoremask are not compared. Fails if poll() times out.
    """
    p = select.poll()
    p.register(sock, mask)
    expected_fds = [(sock.fileno(), expected)]
    # Don't block forever or we'll hang continuous test runs on failure.
    # A 5-second timeout should be long enough not to be flaky.
    actual_fds = p.poll(5000)
    self.assertEqual(self.PollResultToString(expected_fds, ignoremask),
                     self.PollResultToString(actual_fds, ignoremask))

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Runs call on sock via _EventDuringBlockingCall; the event is a RST.

    The RST is received on self.netid.
    """
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda _: self.ReceiveRstPacketOn(self.netid))

  def assertSocketErrors(self, errno):
    """Checks self.accepted after it was destroyed or reset.

    Args:
      errno: the error the first recv() is expected to fail with.
    """
    # The first operation returns the expected errno.
    self.assertRaisesErrno(errno, self.accepted.recv, 4096)

    # Subsequent operations behave as normal. Bytes literals keep these checks
    # valid on Python 3, where send() rejects str and recv() returns bytes.
    self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
    self.assertEqual(b"", self.accepted.recv(4096))
    self.assertEqual(b"", self.accepted.recv(4096))

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(ECONNABORTED)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.RstDuringBlockingCall(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(ECONNRESET)

  def testReadPollRst(self):
    # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll()
    # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This
    # is due to a race inside the kernel and thus is not visible on the VM, only
    # on physical hardware.
    if net_test.LINUX_VERSION < (4, 14, 0):
      ignoremask = select.POLLIN | select.POLLHUP
    else:
      ignoremask = 0
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollRst(self):
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    ignoremask = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
905
906
@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
class SockDestroyUdpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr):
    """Binds s to a random port in [1024, 65535) on addr and returns it.

    Retries with a new random port, up to ATTEMPTS times, while bind() fails
    with EADDRINUSE.

    Raises:
      socket.error: bind() failed with an errno other than EADDRINUSE.
      ValueError: no free port was found within ATTEMPTS tries.
    """
    ATTEMPTS = 20
    for _ in range(ATTEMPTS):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          # Bare raise keeps the original traceback (raise e loses it on py2).
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, ATTEMPTS))

  def testSocketAddressesAfterClose(self):
    """Checks which local address/port survive SOCK_DESTROY of a UDP socket."""
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])

      # Closing a socket bound to an IP address leaves the address as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY.

    Payloads are bytes literals so send()/sendto() work under Python 3 too.
    """
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      try:
        self.SelectInterface(s, random.choice(self.NETIDS), "mark")
        addr = self.GetRemoteSocketAddress(version)

        # Check that reads on connected sockets are interrupted.
        s.connect((addr, 53))
        self.assertEqual(3, s.send(b"foo"))
        self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                     ECONNABORTED)

        # A destroyed socket is no longer connected, but still usable.
        self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo")
        self.assertEqual(3, s.sendto(b"foo", (addr, 53)))

        # Check that reads on unconnected sockets are also interrupted.
        self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                     ECONNABORTED)
      finally:
        s.close()
1003
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that only privileged users can SOCK_DESTROY sockets."""

  def CheckPermissions(self, socktype):
    """Checks that SOCK_DESTROY fails with EPERM for an unprivileged UID.

    It also checks that the failed attempt leaves the socket untouched and
    that a privileged caller can then destroy it.

    Args:
      socktype: SOCK_STREAM (tested as a listening socket) or SOCK_DGRAM
        (tested as a connected socket).
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
        expectedstate = tcp_test.TCP_LISTEN
      else:
        s.connect((self.GetRemoteAddress(6), 53))
        expectedstate = tcp_test.TCP_ESTABLISHED

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      # The rejected attempt must not have changed the socket's state.
      self.assertEqual(expectedstate,
                       self.sock_diag.FindSockDiagFromFd(s).state)

      self.sock_diag.CloseSocketFromFd(s)
      # A second destroy raises ValueError (presumably because the destroyed
      # socket can no longer be found).
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
1030
1031
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps established TCP sockets matching a MARK_COND(mark, mask) filter."""
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True iff assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos holds exactly one entry per socket in sockets.

    Fails if a socket matches no entry or more than one entry, or if infos
    contains an entry that matches none of the sockets.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the first match so a second one is reported above.
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    try:
      # The low 16 bits of the marks are 0x1234 and 0x1235 respectively.
      s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
      s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

      infos = self.FilterEstablishedSockets(0x1234, 0xffff)
      self.assertFoundSockets(infos, [s1])

      # Masking off the lowest bit makes both marks match.
      infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
      self.assertFoundSockets(infos, [s1, s2])

      infos = self.FilterEstablishedSockets(0x1235, 0xffff)
      self.assertFoundSockets(infos, [s2])

      infos = self.FilterEstablishedSockets(0x0, 0x0)
      self.assertFoundSockets(infos, [s1, s2])

      infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
      self.assertEqual(0, len(infos))

      # Filtering on marks requires privileges.
      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                               0xfff0000, 0xf0fed00)
    finally:
      s1.close()
      s2.close()

  @staticmethod
  def SetRandomMark(s):
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts s reports mark to root and no mark to an unprivileged UID."""
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      server.close()
      client.close()
      accepted.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      if HAVE_UDP_DIAG:
        s = socket(family, SOCK_DGRAM, 0)
        mark = self.SetRandomMark(s)
        s.connect(("", 53))
        self.assertSocketMarkIs(s, mark)
        s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
1175
1176
if __name__ == "__main__":
  # Hand control to unittest's command-line runner for this module's tests.
  unittest.main(module="__main__")
1179