1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import os
20import random
21import re
22from socket import *  # pylint: disable=wildcard-import
23import threading
24import time
25import unittest
26
27import multinetwork_base
28import net_test
29import netlink
30import packets
31import sock_diag
32import tcp_test
33
34
# Number of socket pairs created by SockDiagBaseTest._CreateLotsOfSockets.
NUM_SOCKETS = 30
# Empty bytecode string, passed to dump requests that should not be filtered.
NO_BYTECODE = ""
37
38
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

    Relevant kernel commits:
      android-3.4:
        ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.10:
        3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.18:
        e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-4.4:
        525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS connected socket pairs on loopback addresses.

    Each pair uses a randomly chosen family and address: 127.0.0.1 (AF_INET),
    ::1 (AF_INET6) or the IPv4-mapped ::ffff:127.0.0.1 (AF_INET6).

    Args:
      socktype: the socket type to create, e.g. SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock has a peer (getpeername() does not raise)."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes socket s.

    Compares the family, local address and port, and socket mark. If the
    diag_msg has a specified destination address, also compares the peer
    address and port; otherwise checks that s is not connected.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs bytecode instructions and checks that they decode cleanly.

    Args:
      instructions: a list of instruction tuples, as accepted by
        SockDiag.PackBytecode.

    Returns:
      The packed bytecode string.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    # A "???" entry marks bytecode that DecodeBytecode could not decode.
    self.assertNotIn("???", decoded)
    return bytecode

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while another thread is blocked in call(sock).

    Args:
      sock: the socket to destroy.
      call: a callable taking sock that is expected to block.
      expected_errno: the errno that the interrupted call must fail with.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call.
    time.sleep(0.1)
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertIsInstance(thread.exception, IOError,
                          "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed by tearDown.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()
134
135
class SockDiagTest(SockDiagBaseTest):
  """Tests SOCK_DIAG socket lookups, dumps and bytecode filtering."""

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    # Keep the pair in self.socketpairs so that tearDown closes it.
    self.socketpairs["mapped"] = socketpair
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def testFindsAllMySockets(self):
    """Tests that basic socket dumping works."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. self.socketpairs is keyed by
    # the ports of the first socket in each pair, so the second socket shows
    # up in the dump with sport and dport swapped.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = self.socketpairs.values()
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testBytecodeCompilation(self):
    """Checks bytecode packing and that compiled filters select sockets."""
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in self.socketpairs.values()[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, NO_BYTECODE, states=states))

    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
    # Keep the pairs in self.socketpairs so that tearDown closes them.
    self.socketpairs["pair4"] = pair4
    self.socketpairs["pair6"] = pair6

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    self.socketpairs["pair5"] = pair5
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that an unknown sock_diag netlink message type gets EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, NO_BYTECODE)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
307
308
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one random end of each TCP pair and checks both ends close.

    Each end is destroyed either directly from its fd or, at random, via a
    request built from its diag_msg. In the request path, a request carrying
    a bogus cookie is first checked to fail with ENOENT without closing the
    socket.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for pair in self.socketpairs.values():
      # Destroying either end sends a RST that closes the other end as well.
      victim = random.choice(pair)
      destroy_by_fd = random.randrange(0, 2) == 1
      if not destroy_by_fd:
        diag_msg = self.sock_diag.FindSockDiagFromFd(victim)

        # A request with the wrong cookie must fail with ENOENT and must
        # leave the socket connected.
        good_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(good_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(victim)

        # Retry with the real cookie, which must succeed.
        req.id.cookie = good_cookie
        self.sock_diag.CloseSocket(req)
      else:
        self.sock_diag.CloseSocketFromFd(victim)

      self.assertSocketsClosed(pair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
386
387
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records any IOError.

  After the thread finishes, self.exception holds the IOError raised by the
  operation, or None if it returned without raising one. Other exception
  types are not caught.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    # Daemon, so an operation that never returns cannot keep the process
    # alive at exit.
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except IOError as e:
      self.exception = e
402
403
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that use connections set up by tcp_test.TcpBaseTest."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(self.tuns.keys())
    # Version 5: IPv4 traffic to an AF_INET6 socket.
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)

    # Dump AF_INET6 sockets in SYN_RECV, with sport set to our port.
    syn_recv_only = 1 << tcp_test.TCP_SYN_RECV
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    req = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, syn_recv_only, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    self.assertTrue(children)

    for child, _ in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)
428
429
class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets in various states."""

  def setUp(self):
    super(SockDestroyTcpTest, self).setUp()
    self.netid = random.choice(self.tuns.keys())

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be specified.

    Args:
      sock: a socket to destroy via its fd, or None.
      req: a sock_diag request identifying the socket to destroy, or None.
      expect_reset: whether a RST is expected on self.netid.
      msg: a description used in failure messages.
      do_close: if sock is specified, whether to close() it afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      # sock is None here, so the socket must be identified by req.
      self.assertIsNotNone(req, "Must specify sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Destroys sockets in the given state and checks which ones send RSTs.

    For each IP version, destroying self.s must not send a RST; unless state
    is TCP_LISTEN, destroying self.accepted must send one.
    """
    for version in [4, 5, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != tcp_test.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def testFinWait1Socket(self):
    """Checks that SOCK_DESTROY leaves a FIN_WAIT1 socket untouched."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      net_test.EnableFinWait(self.accepted)
      self.accepted.close()
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, _ = self.sock_diag.GetSockInfo(diag_req)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      _, fin = self.FinPacket()
      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)

      # Destroy the socket and expect no RST.
      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
      diag_msg, _ = self.sock_diag.GetSockInfo(diag_req)

      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
      # because userspace had already closed it.
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      finversion = 4 if version == 5 else version
      _, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket. This does not appear
      # to work on 3.10.
      if net_test.LINUX_VERSION >= (3, 18):
        diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
        infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
        self.assertTrue(any(d.state == tcp_test.TCP_FIN_WAIT2
                            for d, _ in infos),
                        "Expected to find FIN_WAIT2 socket in %s" % infos)

  def FindChildSockets(self, s):
    """Finds the SYN_RECV child sockets of a given listening socket.

    Also checks that a mark filter that matches no child finds nothing.

    Args:
      s: the listening socket.

    Returns:
      A list of sock_diag requests, one per socket in SYN_RECV or ESTABLISHED
      state whose mark, masked with 0xffff, equals self.netid.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
    # Clear the cookie copied from the listener's diag_msg.
    req.id.cookie = "\x00" * 8

    bad_bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))

    bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
    children = self.sock_diag.Dump(req, bytecode)
    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
            for d, _ in children]

  def CheckChildSocket(self, version, statename, parent_first):
    """Checks destroying a listening socket and its child, in either order.

    Args:
      version: the IP version passed to IncomingConnection (4, 5 or 6).
      statename: name of the child's tcp_test state constant, e.g.
        "TCP_SYN_RECV" or "TCP_NOT_YET_ACCEPTED".
      parent_first: if True, destroy the parent before the child; otherwise
        destroy the child first.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    # Before 4.4, we can see those sockets in dumps, but we can't fetch
    # or close them.
    can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

    for child in children:
      if can_close_children:
        diag_msg, attrs = self.sock_diag.GetSockInfo(child)
        self.assertEqual(diag_msg.state, expected_state)
        self.assertMarkIs(self.netid, attrs)
      else:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseChildren():
      """Destroys each child and checks that all of them are gone."""
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      elif can_close_children:
        CloseChildren()
    else:
      if can_close_children:
        CloseChildren()
      CloseParent(False)
    self.s.close()

  def testChildSockets(self):
    """Checks destroying listeners and their children in both orders."""
    for version in [4, 5, 6]:
      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)

  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")

  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      # Version 5 connects an AF_INET6 socket to an IPv4-mapped address, so
      # the expected SYN is built as an IPv4 packet.
      if version == 5:
        remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
        packet_version = 4
      else:
        remoteaddr = self.GetRemoteAddress(version)
        packet_version = version
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, packet_version,
                              self.MyAddress(packet_version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      s.close()
634
635
class SockDestroyUdpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    """Checks that each socket of a connected UDP pair can be destroyed."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for socketpair in self.socketpairs.values():
      s1, s2 = socketpair

      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr):
    """Binds s to addr and a random port in [1024, 65535).

    Tries a new random port whenever the chosen one is already in use.

    Args:
      s: the socket to bind.
      addr: the address to bind to.

    Returns:
      The port that s was bound to.

    Raises:
      ValueError: if every attempt hit EADDRINUSE.
      socket.error: if bind fails with any other errno.
    """
    attempts = 20
    for _ in xrange(attempts):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          # Bare raise preserves the original traceback.
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, attempts))

  def testSocketAddressesAfterClose(self):
    """Checks which local address and port survive SOCK_DESTROY."""
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])

      # Closing a socket bound to an IP address leaves the address as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MyAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MyAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send("foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
      self.assertEqual(3, s.sendto("foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      s.close()
731
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY fails with EPERM when run as another UID."""

  def CheckPermissions(self, socktype):
    """Checks that UID 12345 cannot destroy our socket but we can.

    Args:
      socktype: SOCK_STREAM (a listening socket is tested) or SOCK_DGRAM
        (a connected socket is tested).
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
      else:
        s.connect((self.GetRemoteAddress(6), 53))

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      # Destroying the same socket a second time raises ValueError.
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      # SOCK_DESTROY does not release the fd; close it explicitly.
      s.close()

  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
757
758
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  # Protocol number passed to socket() and DumpAllInetSockets for SCTP.
  IPPROTO_SCTP = 132

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps established TCP sockets selected by a mark/mask bytecode filter.

    The cases in testMarkBytecode are consistent with the filter selecting
    sockets whose (socket_mark & mask) == mark.

    Args:
      mark: the mark value the INET_DIAG_BC_MARK_COND condition compares to.
      mask: the mask the condition applies to the socket mark.

    Returns:
      The result of DumpAllInetSockets for IPPROTO_TCP sockets in the
      TCP_ESTABLISHED state that pass the filter.
    """
    # One MARK_COND op. The 1 and 2 are presumably the op's "yes"/"no" jump
    # offsets -- TODO confirm against sock_diag.PackBytecode.
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts the source ports in diag_msgs are exactly ports (any order)."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True if assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    """Returns a "local -> peer" string for s, used in failure messages."""
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos and sockets correspond one-to-one.

    Fails if a socket matches no entry in infos, or matches more than one
    entry. Also fails if infos contains an entry that matches none of the
    sockets.

    Args:
      infos: the dump entries returned by a SOCK_DIAG query.
      sockets: the socket objects expected to appear in infos.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the first match so a second one is reported above.
          match = info
      self.assertTrue(match is not None, "Did not find socket %s in dump" %
                      self.SocketDescription(s))
      matches[s] = match

    matched = list(matches.values())
    for i in infos:
      if i not in matched:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    """Tests that mark bytecode filters select sockets by mark and mask."""
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    try:
      s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
      s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

      infos = self.FilterEstablishedSockets(0x1234, 0xffff)
      self.assertFoundSockets(infos, [s1])

      infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
      self.assertFoundSockets(infos, [s1, s2])

      infos = self.FilterEstablishedSockets(0x1235, 0xffff)
      self.assertFoundSockets(infos, [s2])

      infos = self.FilterEstablishedSockets(0x0, 0x0)
      self.assertFoundSockets(infos, [s1, s2])

      infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
      self.assertEqual(0, len(infos))

      # Mark filters are refused with EPERM when running as UID 12345.
      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                               0xfff0000, 0xf0fed00)
    finally:
      s1.close()
      s2.close()

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns the mark that was set."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Checks the mark reported for s with and without privilege.

    As the current user, the attributes must carry the given mark. As
    UID 12345, assertMarkIs must see no mark (None).
    """
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    """Tests that socket marks are reported in SOCK_DIAG dump attributes."""
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      server.close()
      client.close()
      accepted.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if net_test.LINUX_VERSION >= (4, 7, 0):
        s = socket(family, SOCK_STREAM, self.IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(self.IPPROTO_SCTP,
                                                    NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
904
905
906if __name__ == "__main__":
907  unittest.main()
908