1#!/usr/bin/python
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import fcntl
18import os
19import random
20import re
21from socket import *  # pylint: disable=wildcard-import
22import struct
23import sys
24import unittest
25
26from scapy import all as scapy
27
28import csocket
29
# TODO: Move these to csocket.py.
# IP / IPv6 socket option levels and option names.
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

# SOL_SOCKET-level option names.
SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethernet protocol (ethertype) numbers.
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

IPPROTO_GRE = 47

# ioctl used by SetInterfaceHWAddr to change an interface's hardware address.
SIOCSIFHWADDR = 0x8924

# Flow label manager actions (flr_action of struct in6_flowlabel_req).
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
# The kernel defines RENEW as 2; declaring it as 1 made it an alias of PUT.
IPV6_FL_A_RENEW = 2

# Flow label manager flags (flr_flags).
IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

# Flow label sharing modes (flr_share).
IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

# Size of the interface-name field packed into struct ifreq by the ioctl helpers.
IFNAMSIZ = 16
65
# Raw 8-byte ICMP (type 8) and ICMPv6 (type 128) echo request headers, with the
# checksum field left as zero.
IPV4_PING = "\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"

# Fixed remote IPv4 / IPv6 destination addresses.
IPV4_ADDR = "8.8.8.8"
IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
IPV6_ADDR2 = "2001:4860:4860::8844"

# Expected first line of /proc/net/{icmp6,raw6,udp6}; NetworkTest's
# ReadProcNetSocket compares the file's header line against it.
IPV6_SEQ_DGRAM_HEADER = ("  sl  "
                         "local_address                         "
                         "remote_address                        "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         "   uid  timeout inode ref pointer drops\n")

# Length of a UDP header in bytes.
UDP_HDR_LEN = 8

# Arbitrary packet payload: a scapy-built DNS AAAA query with recursion desired
# and a random query ID.
# NOTE(review): str() gives the raw packet bytes on Python 2 only; a Python 3
# port presumably needs bytes(...) here — confirm before porting.
UDP_PAYLOAD = str(scapy.DNS(rd=1,
                            id=random.randint(0, 65535),
                            qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                           qtype="AAAA")))

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# Version of the running kernel, as reported by csocket.LinuxVersion().
LINUX_VERSION = csocket.LinuxVersion()
# Lowest possible version tuple; presumably used as a "no minimum kernel"
# bound when comparing against LINUX_VERSION — confirm with callers.
LINUX_ANY_VERSION = (0, 0)
96
def GetWildcardAddress(version):
  """Return the unspecified ("any") address string for an IP version.

  Args:
    version: 4 for IPv4, 6 for IPv6, or 5 for IPv4-mapped IPv6 (which is used
      on an AF_INET6 socket and therefore binds to the IPv6 wildcard).

  Returns:
    "0.0.0.0" for version 4, "::" for versions 5 and 6.

  Raises:
    KeyError: if version is not 4, 5 or 6.
  """
  # Version 5 is accepted for consistency with GetAddressFamily and
  # BindRandomPort, which both treat it as an AF_INET6 socket.
  return {4: "0.0.0.0", 5: "::", 6: "::"}[version]
99
def GetIpHdrLength(version):
  """Return the fixed IP header length in bytes for `version`.

  20 for IPv4 (no options), 40 for IPv6; any other version raises KeyError.
  """
  lengths_by_version = {4: 20, 6: 40}
  return lengths_by_version[version]
102
def GetAddressFamily(version):
  """Map an IP version to the socket family used to talk it.

  Version 4 uses AF_INET; versions 5 (IPv4-mapped IPv6) and 6 use AF_INET6.
  Any other version raises KeyError.
  """
  family_for = {
      4: AF_INET,
      5: AF_INET6,
      6: AF_INET6,
  }
  return family_for[version]
105
106
def AddressLengthBits(version):
  """Return the address width in bits: 32 for IPv4, 128 for IPv6."""
  bits_per_version = {4: 32, 6: 128}
  return bits_per_version[version]
109
def GetAddressVersion(address):
  """Classify an IP address literal.

  Args:
    address: an IPv4 or IPv6 address string.

  Returns:
    4 for an IPv4 address, 5 for an IPv4-mapped IPv6 address
    (::ffff:a.b.c.d), and 6 for any other IPv6 address.
  """
  if ":" not in address:
    return 4
  try:
    packed = inet_pton(AF_INET6, address)
  except (EnvironmentError, ValueError):
    # Not parseable as a plain IPv6 literal (e.g. it carries a %scope
    # suffix): fall back to a case-insensitive textual check.
    return 5 if address.lower().startswith("::ffff:") else 6
  # Compare the binary prefix rather than the text, so upper-case and
  # uncompressed spellings are recognized, while "::ffff" itself and
  # "::ffff:0:a.b.c.d" (which are not mapped addresses) are not.
  return 5 if packed[:12] == b"\x00" * 10 + b"\xff\xff" else 6
116
def SetSocketTos(s, tos):
  """Set the IPv4 TOS byte or the IPv6 traffic class on socket s.

  Raises KeyError if s is neither AF_INET nor AF_INET6.
  """
  level, option = {
      AF_INET: (SOL_IP, IP_TOS),
      AF_INET6: (SOL_IPV6, IPV6_TCLASS),
  }[s.family]
  s.setsockopt(level, option, tos)
121
122
def SetNonBlocking(fd):
  """Add O_NONBLOCK to file descriptor fd, keeping its other status flags."""
  current_flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  new_flags = current_flags | os.O_NONBLOCK
  fcntl.fcntl(fd, fcntl.F_SETFL, new_flags)
126
127
# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Create a socket and apply csocket.SetSocketTimeout(sock, 5000) to it.

  NOTE(review): 5000 is presumably a timeout in milliseconds — confirm in
  csocket.
  """
  new_socket = socket(family, sock_type, protocol)
  csocket.SetSocketTimeout(new_socket, 5000)
  return new_socket
133
134
def PingSocket(family):
  """Create an ICMP (AF_INET) or ICMPv6 (AF_INET6) datagram socket."""
  icmp_protocols = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}
  protocol = icmp_protocols[family]
  return Socket(family, SOCK_DGRAM, protocol)
138
139
def IPv4PingSocket():
  """Create an IPv4 ICMP datagram socket."""
  return PingSocket(family=AF_INET)
142
143
def IPv6PingSocket():
  """Create an IPv6 ICMPv6 datagram socket."""
  return PingSocket(family=AF_INET6)
146
147
def TCPSocket(family):
  """Create a non-blocking TCP socket for the given address family."""
  tcp_sock = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  fd = tcp_sock.fileno()
  SetNonBlocking(fd)
  return tcp_sock
152
153
def IPv4TCPSocket():
  """Create a non-blocking IPv4 TCP socket."""
  return TCPSocket(family=AF_INET)
156
157
def IPv6TCPSocket():
  """Create a non-blocking IPv6 TCP socket."""
  return TCPSocket(family=AF_INET6)
160
161
def UDPSocket(family):
  """Create a UDP socket for the given address family via Socket()."""
  return Socket(family, sock_type=SOCK_DGRAM, protocol=IPPROTO_UDP)
164
165
def RawGRESocket(family):
  """Create a raw socket for IP protocol IPPROTO_GRE (47)."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)
169
170
def BindRandomPort(version, sock):
  """Bind sock to a kernel-chosen port on the wildcard address.

  SO_REUSEADDR is enabled before binding, and TCP sockets are also put into
  the listening state with a backlog of 100.

  Args:
    version: 4, 5 (IPv4-mapped IPv6) or 6; anything else raises KeyError.
    sock: the socket to bind.

  Returns:
    The local port number assigned by the kernel.
  """
  wildcard = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((wildcard, 0))
  is_tcp = sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP
  if is_tcp:
    sock.listen(100)
  return sock.getsockname()[1]
179
180
def EnableFinWait(sock):
  """Turn SO_LINGER off so that close() leaves the socket in FIN_WAIT."""
  linger_off = struct.pack("ii", 0, 0)  # l_onoff = 0, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_off)
184
185
def DisableFinWait(sock):
  """Enable SO_LINGER with a zero timeout so that close() sends RST."""
  linger_zero = struct.pack("ii", 1, 0)  # l_onoff = 1, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_zero)
189
190
def CreateSocketPair(family, socktype, addr):
  """Create two sockets of the given family and type connected to each other.

  For SOCK_STREAM, a temporary listening socket accepts the client's
  connection and is then closed. Both ends get DisableFinWait (SO_LINGER with
  a zero timeout), so closing them sends RST. For other socket types the bound
  socket is itself connected back to the client.

  Args:
    family: address family, e.g. AF_INET or AF_INET6.
    socktype: SOCK_STREAM or a datagram socket type.
    addr: local address to bind the second socket to (the port is chosen by
      the kernel).

  Returns:
    A (clientsock, acceptedsock) tuple.
  """
  clientsock = socket(family, socktype, 0)
  listensock = acceptedsock = None
  try:
    listensock = socket(family, socktype, 0)
    listensock.bind((addr, 0))
    addr = listensock.getsockname()
    if socktype == SOCK_STREAM:
      listensock.listen(1)
    clientsock.connect(addr)
    if socktype == SOCK_STREAM:
      acceptedsock, _ = listensock.accept()
      DisableFinWait(clientsock)
      DisableFinWait(acceptedsock)
      listensock.close()
    else:
      listensock.connect(clientsock.getsockname())
      acceptedsock = listensock
  except BaseException:
    # Don't leak file descriptors if any step fails. Closing a socket twice
    # (e.g. listensock == acceptedsock) is harmless.
    for s in (clientsock, listensock, acceptedsock):
      if s is not None:
        s.close()
    raise
  return clientsock, acceptedsock
208
209
def GetInterfaceIndex(ifname):
  """Return the kernel interface index of ifname using SIOCGIFINDEX.

  Args:
    ifname: interface name, packed into an IFNAMSIZ-byte field.

  Returns:
    The interface index as an int.
  """
  s = UDPSocket(AF_INET)
  try:
    # Request layout: IFNAMSIZ-byte name followed by an int for the index.
    ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
    return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
  finally:
    # The socket exists only to issue the ioctl; close it explicitly rather
    # than relying on garbage collection.
    s.close()
215
216
def SetInterfaceHWAddr(ifname, hwaddr):
  """Set the Ethernet hardware address of ifname using SIOCSIFHWADDR.

  Args:
    ifname: interface name, packed into an IFNAMSIZ-byte field.
    hwaddr: hardware address as hex digits, optionally colon-separated
      (e.g. "02:00:00:00:00:01").

  Raises:
    ValueError: if hwaddr does not decode to exactly 6 bytes.
  """
  import binascii  # Local import: portable replacement for str.decode("hex").
  hwaddr = binascii.unhexlify(hwaddr.replace(":", ""))
  if len(hwaddr) != 6:
    raise ValueError("Unknown hardware address length %d" % len(hwaddr))
  # Only open the ioctl socket once the input has been validated.
  s = UDPSocket(AF_INET)
  try:
    ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname, scapy.ARPHDR_ETHER, hwaddr)
    fcntl.ioctl(s, SIOCSIFHWADDR, ifr)
  finally:
    s.close()
225
226
def SetInterfaceState(ifname, up):
  """Set or clear IFF_UP on ifname, preserving its other interface flags.

  The current flags are read with SIOCGIFFLAGS, IFF_UP is set (up is true) or
  cleared (up is false), and the result is written back with SIOCSIFFLAGS.
  """
  s = UDPSocket(AF_INET)
  try:
    ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
    _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
    if up:
      flags |= scapy.IFF_UP
    else:
      flags &= ~scapy.IFF_UP
    ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, flags)
    fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)
  finally:
    s.close()
238
239
def SetInterfaceUp(ifname):
  """Set IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=True)
242
243
def SetInterfaceDown(ifname):
  """Clear IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=False)
246
247
def CanonicalizeIPv6Address(addr):
  """Round-trip addr through inet_pton/inet_ntop to get its canonical text form."""
  packed = inet_pton(AF_INET6, addr)
  return inet_ntop(AF_INET6, packed)
250
251
def FormatProcAddress(unformatted):
  """Convert a raw hex IPv6 address from /proc into canonical text form.

  e.g. "20014860486000000000000000008888" -> "2001:4860:4860::8888".
  """
  colon_separated = ":".join(unformatted[pos:pos + 4]
                             for pos in range(0, len(unformatted), 4))
  # Let inet_pton/inet_ntop do the zero compression.
  return CanonicalizeIPv6Address(colon_separated)
260
261
def FormatSockStatAddress(address):
  """Format an IP address as uppercase hex, one 8-digit word per 32 bits.

  Each 4-byte word of the packed address is read in host byte order, which is
  the address layout ReadProcNetSocket's regex expects from /proc/net tables.
  """
  family = AF_INET6 if ":" in address else AF_INET
  packed = inet_pton(family, address)
  words = (struct.unpack("=L", packed[offset:offset + 4])[0]
           for offset in range(0, len(packed), 4))
  return "".join("%08X" % word for word in words)
272
273
def GetLinkAddress(ifname, linklocal):
  """Return an IPv6 address of ifname as listed in /proc/net/if_inet6.

  Args:
    ifname: interface name, matched against the sixth column of the file.
    linklocal: if true, return the first address starting with "fe80";
      otherwise the first one that does not.

  Returns:
    The matching address in canonical text form, or None if none matches.
  """
  # Close the /proc file promptly instead of leaving it to garbage collection.
  with open("/proc/net/if_inet6") as f:
    lines = f.readlines()
  for line in lines:
    fields = [s for s in line.strip().split(" ") if s]
    if fields[5] != ifname:
      continue
    # Match when "is link-local" agrees with what the caller asked for.
    if fields[0].startswith("fe80") == bool(linklocal):
      # Convert the address from raw hex to something with colons in it.
      return FormatProcAddress(fields[0])
  return None
284
285
def GetDefaultRoute(version=6):
  """Find the default route for an IP version by parsing /proc.

  Args:
    version: 6 (reads /proc/net/ipv6_route) or 4 (reads /proc/net/route).

  Returns:
    A (gateway, interface) tuple of strings.

  Raises:
    ValueError: if no default route is found or version is not 4 or 6.
  """
  if version == 6:
    with open("/proc/net/ipv6_route") as f:
      routes = f.readlines()
    for route in routes:
      route = [s for s in route.strip().split(" ") if s]
      if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
          # Routes in non-default tables end up in /proc/net/ipv6_route!!!
          and route[9] != "lo" and not route[9].startswith("nettest")):
        return FormatProcAddress(route[4]), route[9]
    raise ValueError("No IPv6 default route found")
  elif version == 4:
    with open("/proc/net/route") as f:
      routes = f.readlines()
    for route in routes:
      route = [s for s in route.strip().split("\t") if s]
      # Destination 0.0.0.0 with mask 0.0.0.0 is the default route.
      if route[1] == "00000000" and route[7] == "00000000":
        gw, iface = route[2], route[0]
        # The gateway is printed as a host-order 32-bit hex word. Repacking it
        # in native order recovers the network-order bytes on any host,
        # unlike reversing the hex bytes (which assumed little-endian).
        gw = inet_ntop(AF_INET, struct.pack("=L", int(gw, 16)))
        return gw, iface
    raise ValueError("No IPv4 default route found")
  else:
    raise ValueError("Don't know about IPv%s" % version)
307
308
def GetDefaultRouteInterface():
  """Return the interface name of the IPv6 default route."""
  return GetDefaultRoute()[1]
312
313
def MakeFlowLabelOption(addr, label):
  """Build a struct in6_flowlabel_req for the IPV6_FLOWLABEL_MGR option.

  The request uses action IPV6_FL_A_GET, share mode IPV6_FL_S_ANY and flag
  IPV6_FL_F_CREATE, with zero expiry and linger times.

  Args:
    addr: IPv6 destination address string (flr_dst).
    label: flow label; only the low 20 bits are used.

  Returns:
    The packed 32-byte request.
  """
  # struct in6_flowlabel_req {
  #         struct in6_addr flr_dst;
  #         __be32  flr_label;
  #         __u8    flr_action;
  #         __u8    flr_share;
  #         __u16   flr_flags;
  #         __u16   flr_expires;
  #         __u16   flr_linger;
  #         __u32   __flr_pad;
  #         /* Options in format of IPV6_PKTOPTIONS */
  # };
  fmt = "16sIBBHHH4s"
  assert struct.calcsize(fmt) == 32
  addr = inet_pton(AF_INET6, addr)
  assert len(addr) == 16
  # flr_label is __be32: convert to network order, then pack it natively.
  label = htonl(label & 0xfffff)
  action = IPV6_FL_A_GET
  share = IPV6_FL_S_ANY
  flags = IPV6_FL_F_CREATE
  # The "4s" field needs bytes; a str pad is rejected by struct on Python 3.
  pad = b"\x00" * 4
  return struct.pack(fmt, addr, label, action, share, flags, 0, 0, pad)
336
337
def SetFlowLabel(s, addr, label):
  """Request flow label `label` toward addr on socket s (IPV6_FLOWLABEL_MGR).

  This only registers the label. The caller also needs to enable sending it:
    s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1)
  """
  request = MakeFlowLabelOption(addr, label)
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, request)
342
343
def RunIptablesCommand(version, args):
  """Run iptables (version 4) or ip6tables (version 6) and wait for it.

  The binary comes from /sbin when it is executable there, otherwise from
  /system/bin. args is a single string split on single spaces into argv.

  Returns:
    The value returned by os.spawnvp(os.P_WAIT, ...).
  """
  binary = {4: "iptables", 6: "ip6tables"}[version]
  path = "/sbin/" + binary
  if os.access(path, os.X_OK):
    pass
  else:
    path = "/system/bin/" + binary
  argv = [path]
  argv.extend(args.split(" "))
  return os.spawnvp(os.P_WAIT, path, argv)
350
# Determine network configuration: an IP version counts as available when
# GetDefaultRoute finds a default route for it.
def _HasDefaultRoute(version):
  try:
    GetDefaultRoute(version=version)
  except ValueError:
    return False
  return True


HAVE_IPV4 = _HasDefaultRoute(4)
HAVE_IPV6 = _HasDefaultRoute(6)
363
class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and GID.

  A falsy uid or gid (e.g. 0) leaves that identity unchanged. When a uid is
  given, AID_INET is added to the supplementary groups so the block can still
  open sockets as non-root. The original credentials are restored on exit.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    try:
      if self.uid:
        self.saved_uids = os.getresuid()
        self.saved_groups = os.getgroups()
        os.setgroups(self.saved_groups + [AID_INET])
        try:
          # Keep the original real UID as the saved set-user-ID so that
          # __exit__ can switch back.
          os.setresuid(self.uid, self.uid, self.saved_uids[0])
        except BaseException:
          os.setgroups(self.saved_groups)
          raise
    except BaseException:
      # __exit__ is not called when __enter__ raises, so undo any partial
      # change here instead of leaking it into whatever runs next.
      if self.gid:
        os.setgid(self.saved_gid)
      raise

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Undo in the reverse order of __enter__: UIDs, then groups, then GID.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)
387
class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID, leaving the GID as is."""

  def __init__(self, uid):
    # A gid of 0 tells RunAsUidGid not to change the GID.
    RunAsUidGid.__init__(self, uid, 0)
393
class NetworkTest(unittest.TestCase):
  """Base TestCase for network tests, with errno and /proc parsing helpers."""

  def assertRaisesRegex(self, *args, **kwargs):
    """Portable assertRaisesRegex; Python 2 only has assertRaisesRegexp."""
    if sys.version_info.major < 3:
      return self.assertRaisesRegexp(*args, **kwargs)
    else:
      return super().assertRaisesRegex(*args, **kwargs)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Test that the system returns an errno error.

    This works similarly to unittest.TestCase.assertRaises. You can call it as
    an assertion, or use it as a context manager.
    e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    The error is recognized by its os.strerror() text appearing in the
    string form of the raised EnvironmentError.

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f

    Returns:
      A context manager if f is None, otherwise None.
    """
    # The message is used as a regular expression, so escape it to make sure
    # it is matched literally.
    msg = re.escape(os.strerror(err_num))
    if f is None:
      return self.assertRaisesRegex(EnvironmentError, msg)
    else:
      self.assertRaisesRegex(EnvironmentError, msg, f, *args)

  def ReadProcNetSocket(self, protocol):
    """Parse the socket table in /proc/net/<protocol>.

    Args:
      protocol: a /proc/net file name starting with tcp, udp, raw or icmp,
        e.g. "tcp6" or "udp".

    Returns:
      One list per socket line: [src, dst, state, mem, uid, refcnt, extra],
      with every field a string exactly as captured from the file (src and
      dst are "HEXADDR:PORT").

    Raises:
      ValueError: if protocol is unsupported or a line fails to parse.
    """
    # Read file, closing it promptly.
    filename = "/proc/net/%s" % protocol
    with open(filename) as f:
      lines = f.readlines()

    # Possibly check, and strip, header.
    if protocol in ["icmp6", "raw6", "udp6"]:
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    lines = lines[1:]

    # Address width in hex digits: 32 for IPv6 tables, 8 for IPv4 ones.
    if protocol.endswith("6"):
      addrlen = 32
    else:
      addrlen = 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      end_regexp = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    regexp = re.compile(r" *(\d+): "                    # bucket
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # srcaddr, port
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # dstaddr, port
                        "([0-9A-F][0-9A-F]) "           # state
                        "([0-9A-F]{8}:[0-9A-F]{8}) "    # mem (tx:rx queue)
                        "([0-9A-F]{2}:[0-9A-F]{8}) "    # tr:tm->when
                        "([0-9A-F]{8}) +"               # retrnsmt
                        "([0-9]+) +"                    # uid
                        "([0-9]+) +"                    # timeout
                        "([0-9]+) +"                    # inode
                        "([0-9]+) +"                    # refcnt
                        "([0-9a-f]+)"                   # sp
                        "%s"                            # icmp has spaces
                        % (addrlen, addrlen, end_regexp))
    # Return a list of lists with only source / dest addresses for now.
    # TODO: consider returning a dict or namedtuple instead.
    out = []
    for line in lines:
      m = regexp.match(line)
      if m is None:
        raise ValueError("Failed match on [%s]" % line)
      (_, src, dst, state, mem,
       _, _, uid, _, _, refcnt, _, extra) = m.groups()
      out.append([src, dst, state, mem, uid, refcnt, extra])
    return out

  @staticmethod
  def GetConsoleLogLevel():
    """Return the first field of /proc/sys/kernel/printk as an int."""
    with open("/proc/sys/kernel/printk") as f:
      return int(f.readline().split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Write `level` to /proc/sys/kernel/printk and return write()'s result."""
    # Close (and so flush) the file inside this call, so that a rejected write
    # raises here instead of being silently dropped during garbage collection.
    with open("/proc/sys/kernel/printk", "w") as f:
      return f.write("%s\n" % level)
482
483
if __name__ == "__main__":
  # When executed directly, run this module's tests with the unittest runner.
  unittest.main()
486