#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import *  # pylint: disable=wildcard-import
import os
import random
import re
from socket import *  # pylint: disable=wildcard-import
import threading
import time
import unittest

import multinetwork_base
import net_test
import netlink
import packets
import sock_diag
import tcp_test


NUM_SOCKETS = 30
NO_BYTECODE = ""


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS socket pairs on randomly chosen loopback addresses.

    Each pair uses one of 127.0.0.1, ::1 or the IPv4-mapped ::ffff:127.0.0.1.

    Args:
      socktype: the socket type to create, e.g. SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock still has a peer (getpeername() does not raise)."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute equals mark (None if absent)."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Checks that a (diag_msg, attrs) dump entry describes socket s.

    Compares family, source address/port, destination address/port (or, when
    the dumped destination is the unspecified address, that s has no peer),
    and the socket mark.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs instructions into inet_diag bytecode and sanity-checks it.

    The packed bytecode must decode back into the same number of instructions,
    none of which may be undecodable ("???").

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while call(sock) is blocked in another thread.

    Starts call(sock) on a daemon thread, waits 0.1s, destroys the socket via
    SOCK_DESTROY, and asserts that the call returned within 1s by raising an
    IOError with expected_errno, and that the socket is now closed.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    time.sleep(0.1)
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertTrue(isinstance(thread.exception, IOError),
                    "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed in tearDown.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()


class SockDiagTest(SockDiagBaseTest):

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def testFindsAllMySockets(self):
    """Tests that basic socket dumping works."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. Each pair is keyed by the
    # ports of its first socket, so the second socket matches with the ports
    # swapped.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testBytecodeCompilation(self):
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                       states=states))

    unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)


class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one end of each TCP pair and checks both ends are closed.

    Half of the time the socket is destroyed via a request built from its
    dump entry, after first checking that a wrong cookie yields ENOENT and
    leaves the socket connected.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records any IOError.

  After the thread finishes, self.exception holds the IOError raised by the
  operation, or None if it returned normally.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except IOError as e:
      self.exception = e


class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
      android-3.4:
        457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(list(self.tuns.keys()))
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << tcp_test.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    self.assertTrue(children)
    for child, unused_args in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)


class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  def setUp(self):
    super(SockDestroyTcpTest, self).setUp()
    self.netid = random.choice(list(self.tuns.keys()))

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be given: sock is destroyed by file
    descriptor, req is destroyed via a sock_diag request.

    Args:
      sock: the socket to destroy, or None.
      req: the diag request identifying the socket to destroy, or None.
      expect_reset: whether a RST is expected on self.netid.
      msg: prefix for failure messages.
      do_close: whether to also close sock afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      # The original asserted that sock was None here, which is always true
      # in this branch; what must actually hold is that req was supplied.
      self.assertIsNotNone(req, "Must specify sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Destroys incoming and accepted sockets in state and checks for RSTs.

    Version 5 denotes IPv4 traffic on an IPv4-mapped AF_INET6 socket.
    """
    for version in [4, 5, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != tcp_test.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def testFinWait1Socket(self):
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      net_test.EnableFinWait(self.accepted)
      self.accepted.close()
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      desc, fin = self.FinPacket()
      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)

      # Destroy the socket and expect no RST.
      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)

      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
      # because userspace had already closed it.
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      finversion = 4 if version == 5 else version
      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket. This does not appear
      # to work on 3.10.
      if net_test.LINUX_VERSION >= (3, 18):
        diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
        infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
        self.assertTrue(any(d.state == tcp_test.TCP_FIN_WAIT2
                            for d, _ in infos),
                        "Expected to find FIN_WAIT2 socket in %s" % infos)

  def FindChildSockets(self, s):
    """Finds the child sockets of listening socket s.

    Dumps SYN_RECV and ESTABLISHED sockets using a request derived from s's
    own diag entry (with a zeroed cookie), filtered by a mark condition on
    self.netid. Also checks that a mark filter on 0xffff matches nothing.

    Returns:
      A list of diag requests, one per child socket found.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
    req.id.cookie = "\x00" * 8

    bad_bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))

    bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
    children = self.sock_diag.Dump(req, bytecode)
    return [self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
            for child, _ in children]

  def CheckChildSocket(self, version, statename, parent_first):
    """Destroys a listener and its child, in either order, checking RSTs.

    Args:
      version: 4, 6, or 5 for IPv4 on an IPv4-mapped AF_INET6 socket.
      statename: name of the tcp_test state constant for the child socket.
      parent_first: if True, destroy the parent before the child.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    # Before 4.4, we can see those sockets in dumps, but we can't fetch
    # or close them.
    can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

    for child in children:
      if can_close_children:
        diag_msg, attrs = self.sock_diag.GetSockInfo(child)
        self.assertEqual(diag_msg.state, expected_state)
        self.assertMarkIs(self.netid, attrs)
      else:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      elif can_close_children:
        CloseChildren()
        CheckChildrenClosed()
      self.s.close()
    else:
      if can_close_children:
        CloseChildren()
      CloseParent(False)
      self.s.close()

  def testChildSockets(self):
    for version in [4, 5, 6]:
      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)

  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")

  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      if version == 5:
        # IPv4-mapped address on an AF_INET6 socket; packets on the wire are
        # IPv4.
        remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
        version = 4
      else:
        remoteaddr = self.GetRemoteAddress(version)
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      s.close()


class SockDestroyUdpTest(SockDiagBaseTest):

  """Tests SOCK_DESTROY on UDP sockets.

  Relevant kernel commits:
    upstream net-next:
      5d77dca net: diag: support SOCK_DESTROY for UDP sockets
      f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for socketpair in self.socketpairs.values():
      s1, s2 = socketpair

      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr):
    """Binds s to addr on a random port in [1024, 65535).

    Retries on EADDRINUSE up to ATTEMPTS times; other errors propagate.

    Returns:
      The port s was bound to.

    Raises:
      ValueError: if no free port was found.
    """
    ATTEMPTS = 20
    for _ in xrange(ATTEMPTS):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, ATTEMPTS))

  def testSocketAddressesAfterClose(self):
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])

      # Closing a socket bound to an IP address leaves the address as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MyAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MyAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send("foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
      self.assertEqual(3, s.sendto("foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      s.close()


class SockDestroyPermissionTest(SockDiagBaseTest):

  def CheckPermissions(self, socktype):
    """Checks that only a privileged uid can destroy a socket.

    Destroying as uid 12345 must fail with EPERM; destroying as the test's own
    uid must succeed, after which a second destroy raises ValueError.
    """
    s = socket(AF_INET6, socktype, 0)
    self.SelectInterface(s, random.choice(self.NETIDS), "mark")
    if socktype == SOCK_STREAM:
      s.listen(1)
    else:
      s.connect((self.GetRemoteAddress(6), 53))

    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(
          EPERM, self.sock_diag.CloseSocketFromFd, s)

    self.sock_diag.CloseSocketFromFd(s)
    self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    s.close()

  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)


class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

  Relevant kernel commits:
    upstream net-next:
      627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
      a52e95a net: diag: allow socket bytecode filters to match socket marks
      d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  IPPROTO_SCTP = 132

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps ESTABLISHED TCP sockets matching a (mark, mask) mark condition."""
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts that the source ports in diag_msgs equal ports, in any order."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True iff assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos contains exactly one entry for each of sockets.

    Fails if a socket matches no entry or more than one entry, or if any
    entry in infos does not correspond to one of sockets.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the first match so a second one is reported above.
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns it."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts the dumped mark of s is mark, and is hidden from uid 12345."""
    diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      server.close()
      client.close()
      accepted.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if net_test.LINUX_VERSION >= (4, 7, 0):
        s = socket(family, SOCK_STREAM, self.IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(self.IPPROTO_SCTP,
                                                    NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()


if __name__ == "__main__":
  unittest.main()