#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of sock_diag functionality."""

# pylint: disable=g-bad-todo

import errno
from socket import *  # pylint: disable=wildcard-import
import struct

import cstruct
import net_test
import netlink

### Base netlink constants. See include/uapi/linux/netlink.h.
# Netlink protocol number used when opening the sock_diag socket.
NETLINK_SOCK_DIAG = 4

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types (consecutive values 20 and 21).
SOCK_DIAG_BY_FAMILY, SOCK_DESTROY = 20, 21

### inet_diag constants. See include/uapi/linux/inet_diag.h.
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
INET_DIAG_REQ_BYTECODE = 1

# Extensions. These mirror a C enum, so they are numbered consecutively
# from 0 in the order listed (NONE = 0 ... DCTCPINFO = 9).
(INET_DIAG_NONE,
 INET_DIAG_MEMINFO,
 INET_DIAG_INFO,
 INET_DIAG_VEGASINFO,
 INET_DIAG_CONG,
 INET_DIAG_TOS,
 INET_DIAG_TCLASS,
 INET_DIAG_SKMEMINFO,
 INET_DIAG_SHUTDOWN,
 INET_DIAG_DCTCPINFO) = range(10)

# Bytecode operations, also numbered consecutively from 0
# (NOP = 0 ... D_COND = 8).
(INET_DIAG_BC_NOP,
 INET_DIAG_BC_JMP,
 INET_DIAG_BC_S_GE,
 INET_DIAG_BC_S_LE,
 INET_DIAG_BC_D_GE,
 INET_DIAG_BC_D_LE,
 INET_DIAG_BC_AUTO,
 INET_DIAG_BC_S_COND,
 INET_DIAG_BC_D_COND) = range(9)
# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")

SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.

TCP_TIME_WAIT = 6
# State bitmask selecting every TCP state except TIME_WAIT.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)

# Names of the attribute types that _Decode understands for AF_INET/AF_INET6
# messages. An explicit table is used instead of a prefix lookup on
# "INET_DIAG": that prefix also matches INET_DIAG_REQ_BYTECODE and the
# INET_DIAG_BC_* opcodes, whose values overlap these, so a value-based
# lookup can resolve to the wrong name depending on iteration order.
_INET_DIAG_ATTR_NAMES = {
    INET_DIAG_NONE: "INET_DIAG_NONE",
    INET_DIAG_MEMINFO: "INET_DIAG_MEMINFO",
    INET_DIAG_INFO: "INET_DIAG_INFO",
    INET_DIAG_VEGASINFO: "INET_DIAG_VEGASINFO",
    INET_DIAG_CONG: "INET_DIAG_CONG",
    INET_DIAG_TOS: "INET_DIAG_TOS",
    INET_DIAG_TCLASS: "INET_DIAG_TCLASS",
    INET_DIAG_SKMEMINFO: "INET_DIAG_SKMEMINFO",
    INET_DIAG_SHUTDOWN: "INET_DIAG_SHUTDOWN",
    INET_DIAG_DCTCPINFO: "INET_DIAG_DCTCPINFO",
}


class SockDiag(netlink.NetlinkSocket):
  """Netlink socket speaking the NETLINK_SOCK_DIAG protocol."""

  FAMILY = NETLINK_SOCK_DIAG
  NL_DEBUG = []

  def _Decode(self, command, msg, nla_type, nla_data):
    """Decodes netlink attributes to Python types.

    Args:
      command: the netlink message type (unused here).
      msg: the parsed message header; only msg.family is consulted.
      nla_type: the integer attribute type.
      nla_data: the raw attribute payload.

    Returns:
      A (name, data) tuple. name is the INET_DIAG_* name for known inet
      attributes, otherwise the integer nla_type. data is a decoded value
      for known attributes, otherwise the raw payload.
    """
    if msg.family in (AF_INET, AF_INET6):
      name = _INET_DIAG_ATTR_NAMES.get(nla_type, nla_type)
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]:
      # Single-byte attributes.
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      # Congestion control algorithm name, NUL-padded.
      data = nla_data.strip("\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, data):
    """Prints an outgoing request if "ALL" or "SOCK" is in NL_DEBUG."""
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an all-zero InetDiagSockId."""
    return InetDiagSockId("\x00" * len(InetDiagSockId))

  def PackBytecode(self, instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    Args:
      instructions: list of instruction tuples. It is iterated twice, so it
        must be a sequence, not a one-shot iterator.

    Returns:
      A string, the raw bytecode.

    Raises:
      ValueError: if a jump is not positive, jumps past the reject label, or
        an opcode is unsupported.
    """
    args = []
    # positions[i] is the byte offset of instruction i. Two extra entries
    # are appended: the accept label (end of code) and the reject label.
    positions = [0]

    for op, yes, no, arg in instructions:
      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
        arg = ""
      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
        # 4 bytes laid out like an InetDiagBcOp ("BBH"): two zero bytes,
        # then the port in the 16-bit field.
        arg = "\x00\x00" + struct.pack("=H", arg)
      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        addr = inet_pton(family, addr)
        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      positions.append(positions[-1] + len(InetDiagBcOp) + len(arg))

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      if i + yes >= len(positions) or i + no >= len(positions):
        raise ValueError("Jump from instruction %d is past the reject label"
                         % i)
      # Convert instruction-relative jumps into byte offsets.
      yes = positions[i + yes] - positions[i]
      no = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes, no)).Pack() + args[i])

    return "".join(packed)

  def Dump(self, diag_req, bytecode=""):
    """Sends a SOCK_DIAG_BY_FAMILY dump for diag_req and returns the result.

    The result is iterated elsewhere in this class as (InetDiagMsg, attrs)
    pairs.
    """
    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    sockets = []
    for family in [AF_INET, AF_INET6]:
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a binary InetDiagSockId address field to a string.

    The address fields are 16 bytes wide; for AF_INET only the first 4 are
    used. Raises KeyError for families other than AF_INET and AF_INET6.
    """
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format."""
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId."""
    padded = SockDiag.RawAddress(addr)
    if len(padded) < 16:
      padded += "\x00" * (16 - len(padded))
    return padded

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket."""
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    if net_test.LINUX_VERSION >= (3, 8):
      iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                           net_test.IFNAMSIZ)
      # GetInterfaceIndex is not defined or imported in this module; this
      # assumes net_test provides it - confirm.
      iface = net_test.GetInterfaceIndex(iface) if iface else 0
    else:
      iface = 0
    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno != errno.ENOTCONN:
        raise  # Bare raise preserves the original traceback.
      # Unconnected socket: match an all-zero peer.
      dport = 0
      dst = "::" if family == AF_INET6 else "0.0.0.0"
    src = SockDiag.PaddedAddress(src)
    dst = SockDiag.PaddedAddress(dst)
    sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  def FindSockDiagFromReq(self, req):
    """Returns the first InetDiagMsg dumped for req.

    Raises:
      ValueError: if the dump returned no sockets.
    """
    for diag_msg, _ in self.Dump(req):
      return diag_msg
    raise ValueError("Dump of %s returned no sockets" % req)

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    req = self.DiagReqFromSocket(s)
    return self.FindSockDiagFromReq(req)

  def GetSockDiag(self, req):
    """Gets an InetDiagMsg from the kernel for the specified request."""
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)[0]

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us.

    The state mask selects only d's current state.
    """
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request for req with NLM_F_ACK set."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Looks up socket s via sock_diag and sends SOCK_DESTROY for it."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)


if __name__ == "__main__":
  n = SockDiag()
  n.DEBUG = True
  bytecode = ""
  sock_id = n._EmptyInetDiagSockId()
  sock_id.dport = 443
  # Extension bit N-1 requests attribute N.
  ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
  states = 0xffffffff
  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                   sock_id=sock_id, ext=ext, states=states)
  print(diag_msgs)