#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of sock_diag functionality."""

# pylint: disable=g-bad-todo

import errno
import os
from socket import *  # pylint: disable=wildcard-import
import struct

import csocket
import cstruct
import net_test
import netlink

### Base netlink constants. See include/uapi/linux/netlink.h.
NETLINK_SOCK_DIAG = 4

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20
SOCK_DESTROY = 21

### inet_diag_constants. See include/uapi/linux/inet_diag.h
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
INET_DIAG_REQ_BYTECODE = 1

# Extensions.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9
INET_DIAG_PROTOCOL = 10
INET_DIAG_SKV6ONLY = 11
INET_DIAG_LOCALS = 12
INET_DIAG_PEERS = 13
INET_DIAG_PAD = 14
INET_DIAG_MARK = 15

# Bytecode operations.
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
INET_DIAG_BC_DEV_COND = 9
INET_DIAG_BC_MARK_COND = 10

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")
InetDiagMarkcond = cstruct.Struct("InetDiagMarkcond", "=II", "mark mask")

SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.

TCP_TIME_WAIT = 6
# Bitmask with every state bit set except TCP_TIME_WAIT. This is the default
# "states" filter used by SockDiag.DumpAllInetSockets.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)


class SockDiag(netlink.NetlinkSocket):
  """NETLINK_SOCK_DIAG socket that can dump, look up and destroy sockets."""

  FAMILY = NETLINK_SOCK_DIAG
  # Debug categories. MaybeDebugCommand prints outgoing requests when this
  # list contains "ALL" or "SOCK".
  NL_DEBUG = []

  def _Decode(self, command, msg, nla_type, nla_data):
    """Decodes netlink attributes to Python types.

    Args:
      command: The netlink command the attribute belongs to (unused here).
      msg: The parsed message header; its family selects the naming scheme.
      nla_type: Integer attribute type.
      nla_data: Raw attribute payload.

    Returns:
      A (name, data) tuple. name is the INET_DIAG_xxx constant name for inet
      families, or the raw integer type otherwise. data is the decoded
      payload, or the raw payload if the attribute is not recognized.
    """
    if msg.family in (AF_INET, AF_INET6):
      # Request attributes and response attributes share numeric values, so
      # pick the constant prefix based on the kind of message being decoded.
      if isinstance(msg, InetDiagReqV2):
        prefix = "INET_DIAG_REQ"
      else:
        prefix = "INET_DIAG"
      name = self._GetConstantName(__name__, nla_type, prefix)
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS",
                "INET_DIAG_SKV6ONLY"]:
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      data = nla_data.strip("\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    elif name == "INET_DIAG_MARK":
      data = struct.unpack("=I", nla_data)[0]
    elif name == "INET_DIAG_REQ_BYTECODE":
      data = self.DecodeBytecode(nla_data)
    elif name in ["INET_DIAG_LOCALS", "INET_DIAG_PEERS"]:
      data = []
      while nla_data:
        # The SCTP diag code always appears to copy sizeof(sockaddr_storage)
        # bytes, but does so from a union sctp_addr which is at most as long
        # as a sockaddr_in6.
        addr, nla_data = cstruct.Read(nla_data, csocket.SockaddrStorage)
        if addr.family == AF_INET:
          addr = csocket.SockaddrIn(addr.Pack())
        elif addr.family == AF_INET6:
          addr = csocket.SockaddrIn6(addr.Pack())
        data.append(addr)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, unused_flags, data):
    """Prints the request being sent if SOCK (or ALL) debugging is enabled."""
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    # Only resolve the name and parse the message when it will be printed.
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an InetDiagSockId whose fields are all zero bytes."""
    return InetDiagSockId(("\x00" * len(InetDiagSockId)))

  @staticmethod
  def PackBytecode(instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    Args:
      instructions: list of instruction tuples

    Returns:
      A string, the raw bytecode.

    Raises:
      ValueError: if a jump offset is not positive or an opcode is not
        supported by this packer.
    """
    args = []
    # positions[i] is the byte offset of instruction i. Two extra entries are
    # appended after the loop for the accept and reject labels.
    positions = [0]

    for op, yes, no, arg in instructions:

      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
        arg = ""
      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
        # The port is packed into the last two bytes of a four-byte operand;
        # DecodeBytecode reads it back as the "no" field of an InetDiagBcOp.
        arg = "\x00\x00" + struct.pack("=H", arg)
      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        addr = inet_pton(family, addr)
        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
      elif op == INET_DIAG_BC_MARK_COND:
        # A bare mark means "match all 32 bits".
        if isinstance(arg, tuple):
          mark, mask = arg
        else:
          mark, mask = arg, 0xffffffff
        arg = InetDiagMarkcond((mark, mask)).Pack()
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      positions.append(positions[-1] + len(InetDiagBcOp) + len(arg))

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    # Second pass: convert instruction-relative jumps to byte offsets.
    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      yes = positions[i + yes] - positions[i]
      no = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes, no)).Pack() + args[i])

    return "".join(packed)

  @staticmethod
  def DecodeBytecode(bytecode):
    """Decodes raw inet_diag bytecode into a list of instructions.

    Args:
      bytecode: A string, the raw bytecode.

    Returns:
      A list of (InetDiagBcOp, arg) tuples, where arg depends on the opcode
      (None, a port, an (address, prefix_len, port) tuple, an unpacked
      interface value or an InetDiagMarkcond), or the string "???" if the
      bytecode cannot be decoded.
    """
    instructions = []
    try:
      while bytecode:
        op, rest = cstruct.Read(bytecode, InetDiagBcOp)

        if op.code in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
          arg = None
        elif op.code in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                         INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
          # The operand has the same layout as an InetDiagBcOp with the port
          # in the "no" field. Read it into its own variable so the real
          # instruction header in op is what gets recorded below.
          port_op, rest = cstruct.Read(rest, InetDiagBcOp)
          arg = port_op.no
        elif op.code in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
          cond, rest = cstruct.Read(rest, InetDiagHostcond)
          if cond.family == 0:
            # No address follows the host condition in this case.
            arg = (None, cond.prefix_len, cond.port)
          else:
            addrlen = 4 if cond.family == AF_INET else 16
            addr, rest = rest[:addrlen], rest[addrlen:]
            addr = inet_ntop(cond.family, addr)
            arg = (addr, cond.prefix_len, cond.port)
        elif op.code == INET_DIAG_BC_DEV_COND:
          attrlen = struct.calcsize("=I")
          attr, rest = rest[:attrlen], rest[attrlen:]
          # NOTE(review): this yields a 1-tuple rather than an int, unlike
          # the INET_DIAG_MARK decoding in _Decode - confirm this is intended.
          arg = struct.unpack("=I", attr)
        elif op.code == INET_DIAG_BC_MARK_COND:
          arg, rest = cstruct.Read(rest, InetDiagMarkcond)
        else:
          raise ValueError("Unknown opcode %d" % op.code)
        instructions.append((op, arg))
        bytecode = rest

      return instructions
    except (TypeError, ValueError, struct.error):
      # struct.error covers a truncated DEV_COND operand.
      return "???"

  def Dump(self, diag_req, bytecode):
    """Dumps the sockets matching diag_req, optionally filtered by bytecode.

    Args:
      diag_req: An InetDiagReqV2 describing the sockets to dump.
      bytecode: Raw filter bytecode (e.g. from PackBytecode), or an empty
        value for no filter.

    Returns:
      Whatever self._Dump returns for InetDiagMsg responses; callers in this
      class iterate it as (diag_msg, attrs) pairs.
    """
    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    sockets = []
    for family in [AF_INET, AF_INET6]:
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a padded binary address to its string form for family."""
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    # Only the first addrlen bytes of the field are meaningful.
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format."""
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId."""
    padded = SockDiag.RawAddress(addr)
    # InetDiagSockId address fields are 16 bytes, so zero-pad IPv4 addresses.
    if len(padded) < 16:
      padded += "\x00" * (16 - len(padded))
    return padded

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket."""
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    if net_test.LINUX_VERSION >= (3, 8):
      iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                           net_test.IFNAMSIZ)
      # NOTE(review): assumes GetInterfaceIndex is provided by net_test (no
      # such name is defined or imported at module level here) - confirm.
      iface = net_test.GetInterfaceIndex(iface) if iface else 0
    else:
      iface = 0
    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno != errno.ENOTCONN:
        raise  # Bare raise preserves the original traceback.
      # Unconnected socket: use the wildcard address and port.
      dport = 0
      dst = "::" if family == AF_INET6 else "0.0.0.0"
    src = SockDiag.PaddedAddress(src)
    dst = SockDiag.PaddedAddress(dst)
    sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  def FindSockInfoFromFd(self, s):
    """Gets a diag_msg and attrs from the kernel for the specified socket.

    Raises:
      ValueError: if the dump is empty or does not contain the socket's inode.
    """
    req = self.DiagReqFromSocket(s)
    # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
    # only uses them when targeting a specific socket with a cookie. Check
    # the inode number to ensure we don't mistakenly match another socket on
    # the same port but with a different IP address.
    inode = os.fstat(s.fileno()).st_ino
    results = self.Dump(req, "")
    if not results:
      raise ValueError("Dump of %s returned no sockets" % req)
    for diag_msg, attrs in results:
      if diag_msg.inode == inode:
        return diag_msg, attrs
    raise ValueError("Dump of %s did not contain inode %d" % (req, inode))

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    return self.FindSockInfoFromFd(s)[0]

  def GetSockInfo(self, req):
    """Gets a diag_msg and attrs from the kernel for the specified request."""
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us."""
    # Match only the state the socket was reported in.
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request for the socket described by req."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Looks up the given socket via sock_diag and sends SOCK_DESTROY for it."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)


if __name__ == "__main__":
  # Demo: dump all TCP sockets with destination port 443, requesting the TOS
  # and TCLASS extensions, and print the result.
  n = SockDiag()
  n.DEBUG = True
  sock_id = n._EmptyInetDiagSockId()
  sock_id.dport = 443
  ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
  states = 0xffffffff
  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
                                   sock_id=sock_id, ext=ext, states=states)
  print(diag_msgs)