1#!/usr/bin/python
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Base module for multinetwork tests."""
18
19import errno
20import fcntl
21import os
22import posix
23import random
24import re
25from socket import *  # pylint: disable=wildcard-import
26import struct
27
28from scapy import all as scapy
29
30import csocket
31import cstruct
32import iproute
33import net_test
34
35
# Flags and ioctl request for configuring devices through /dev/net/tun
# (see CreateTunInterface).
IFF_TUN = 1
IFF_TAP = 2
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

# Socket option used by BindToDevice to bind a socket to one interface.
SO_BINDTODEVICE = 25

# Setsockopt values. IP_UNICAST_IF and IPV6_UNICAST_IF are used by
# SetUnicastInterface.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values. IP_PKTINFO and IPV6_PKTINFO are used by SendOnNetid.
IP_TTL = 2
IP_PKTINFO = 8
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_PKTINFO = 50
IPV6_HOPLIMIT = 52  # Different from IPV6_UNICAST_HOPS, this is cmsg only.

# Data structures.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Layouts of the in_pktinfo / in6_pktinfo structs packed by MakePktInfo.
InPktinfo = cstruct.Struct("in_pktinfo", "@i4s4s", "ifindex spec_dst addr")
In6Pktinfo = cstruct.Struct("in6_pktinfo", "@16si", "addr ifindex")
60
61
def HaveUidRouting():
  """Checks whether the kernel supports UID routing.

  Temporarily installs an IPv6 rule with a UID range selector and looks for it
  in a rule dump. The probe rule is always removed again.

  Returns:
    True if a dumped rule carries the FRA_UID_START attribute, False if the
    rule could not be created or the selector was dropped by the kernel.
  """
  ipr = iproute.IPRoute()

  # Create a rule with the UID range selector. If the kernel doesn't understand
  # the selector, it will create a rule with no selectors.
  try:
    ipr.UidRangeRule(6, True, 1000, 2000, 100, 10000)
  except IOError:
    return False

  try:
    # Dump all the rules. If we find a rule using the UID range selector, then
    # the kernel supports UID range routing.
    rules = ipr.DumpRules(6)
    return any("FRA_UID_START" in attrs for _, attrs in rules)
  finally:
    # Delete the probe rule even if dumping the rules failed, so it does not
    # linger in the routing policy database.
    ipr.UidRangeRule(6, False, 1000, 2000, 100, 10000)
79
# Sysctl that, where it exists, selects the routing table RA routes are put in
# (setUpClass sets it; SendRA and _RunSetupCommands depend on the outcome).
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"

# Kernel capabilities, probed once at import time. HaveUidRouting() adds and
# then deletes a probe rule, so importing this module touches routing rules.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
HAVE_UID_ROUTING = HaveUidRouting()
84
85
class UnexpectedPacketError(AssertionError):
  """Raised when packets were received but none matched the expected one."""
88
89
def MakePktInfo(version, addr, ifindex):
  """Builds a packed in_pktinfo / in6_pktinfo struct for a PKTINFO cmsg.

  Args:
    version: IP version, 4 or 6. Other values raise KeyError.
    addr: address string to put in the struct. If empty or None, the
      unspecified address of the family ("0.0.0.0" or "::") is used.
    ifindex: interface index to put in the struct.

  Returns:
    The packed struct, as produced by the cstruct Pack() method.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # addr is always non-empty at this point, so convert it unconditionally.
  addr = inet_pton(family, addr)
  if version == 6:
    return In6Pktinfo((addr, ifindex)).Pack()
  # IPv4: the address goes in spec_dst; the addr field is left zeroed.
  return InPktinfo((ifindex, addr, "\x00" * 4)).Pack()
100
101
class MultiNetworkBaseTest(net_test.NetworkTest):

  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since each netid is formatted into a single byte
  # of MAC addresses, IIDs and IPv4 addresses.
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes.
  # NOTE: SetSysctl mutates this dict in place, so it is shared by every
  # subclass that does not define its own.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges (see UidRangeForNetid).
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Re-exported from net_test for convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING
140
141  @classmethod
142  def UidRangeForNetid(cls, netid):
143    return (
144        cls.UID_RANGE_SIZE * netid,
145        cls.UID_RANGE_SIZE * (netid + 1) - 1
146    )
147
148  @classmethod
149  def UidForNetid(cls, netid):
150    return random.randint(*cls.UidRangeForNetid(netid))
151
152  @classmethod
153  def _TableForNetid(cls, netid):
154    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
155      return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
156    else:
157      return netid
158
159  @staticmethod
160  def GetInterfaceName(netid):
161    return "nettest%d" % netid
162
163  @staticmethod
164  def RouterMacAddress(netid):
165    return "02:00:00:00:%02x:00" % netid
166
167  @staticmethod
168  def MyMacAddress(netid):
169    return "02:00:00:00:%02x:01" % netid
170
171  @staticmethod
172  def _RouterAddress(netid, version):
173    if version == 6:
174      return "fe80::%02x00" % netid
175    elif version == 4:
176      return "10.0.%d.1" % netid
177    else:
178      raise ValueError("Don't support IPv%s" % version)
179
180  @classmethod
181  def _MyIPv4Address(cls, netid):
182    return "10.0.%d.2" % netid
183
184  @classmethod
185  def _MyIPv6Address(cls, netid):
186    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
187
188  @classmethod
189  def MyAddress(cls, version, netid):
190    return {4: cls._MyIPv4Address(netid),
191            6: cls._MyIPv6Address(netid)}[version]
192
193  @staticmethod
194  def IPv6Prefix(netid):
195    return "2001:db8:%02x::" % netid
196
197  @staticmethod
198  def GetRandomDestination(prefix):
199    if "." in prefix:
200      return prefix + "%d.%d" % (random.randint(0, 31), random.randint(0, 255))
201    else:
202      return prefix + "%x:%x" % (random.randint(0, 65535),
203                                 random.randint(0, 65535))
204
205  def GetProtocolFamily(self, version):
206    return {4: AF_INET, 6: AF_INET6}[version]
207
208  @classmethod
209  def CreateTunInterface(cls, netid):
210    iface = cls.GetInterfaceName(netid)
211    f = open("/dev/net/tun", "r+b")
212    ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
213    ifr += "\x00" * (40 - len(ifr))
214    fcntl.ioctl(f, TUNSETIFF, ifr)
215    # Give ourselves a predictable MAC address.
216    net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
217    # Disable DAD so we don't have to wait for it.
218    cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
219    net_test.SetInterfaceUp(iface)
220    net_test.SetNonBlocking(f)
221    return f
222
223  @classmethod
224  def SendRA(cls, netid, retranstimer=None):
225    validity = 300                 # seconds
226    macaddr = cls.RouterMacAddress(netid)
227    lladdr = cls._RouterAddress(netid, 6)
228
229    if retranstimer is None:
230      # If no retrans timer was specified, pick one that's as long as the
231      # router lifetime. This ensures that no spurious ND retransmits
232      # will interfere with test expectations.
233      retranstimer = validity
234
235    # We don't want any routes in the main table. If the kernel doesn't support
236    # putting RA routes into per-interface tables, configure routing manually.
237    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0
238
239    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
240          scapy.IPv6(src=lladdr, hlim=255) /
241          scapy.ICMPv6ND_RA(retranstimer=retranstimer,
242                            routerlifetime=routerlifetime) /
243          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
244          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.IPv6Prefix(netid),
245                                      prefixlen=64,
246                                      L=1, A=1,
247                                      validlifetime=validity,
248                                      preferredlifetime=validity))
249    posix.write(cls.tuns[netid].fileno(), str(ra))
250
251  @classmethod
252  def _RunSetupCommands(cls, netid, is_add):
253    for version in [4, 6]:
254      # Find out how to configure things.
255      iface = cls.GetInterfaceName(netid)
256      ifindex = cls.ifindices[netid]
257      macaddr = cls.RouterMacAddress(netid)
258      router = cls._RouterAddress(netid, version)
259      table = cls._TableForNetid(netid)
260
261      # Set up routing rules.
262      if HAVE_UID_ROUTING:
263        start, end = cls.UidRangeForNetid(netid)
264        cls.iproute.UidRangeRule(version, is_add, start, end, table,
265                                 cls.PRIORITY_UID)
266      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
267      cls.iproute.FwmarkRule(version, is_add, netid, table,
268                             cls.PRIORITY_FWMARK)
269
270      # Configure routing and addressing.
271      #
272      # IPv6 uses autoconf for everything, except if per-device autoconf routing
273      # tables are not supported, in which case the default route (only) is
274      # configured manually. For IPv4 we have to manually configure addresses,
275      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
276      #
277      # Since deleting addresses also causes routes to be deleted, we need to
278      # be careful with ordering or the delete commands will fail with ENOENT.
279      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
280      if is_add:
281        if version == 4:
282          cls.iproute.AddAddress(cls._MyIPv4Address(netid), 24, ifindex)
283          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
284        if do_routing:
285          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
286          if version == 6:
287            cls.iproute.AddRoute(version, table,
288                                 cls.IPv6Prefix(netid), 64, None, ifindex)
289      else:
290        if do_routing:
291          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
292          if version == 6:
293            cls.iproute.DelRoute(version, table,
294                                 cls.IPv6Prefix(netid), 64, None, ifindex)
295        if version == 4:
296          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
297          cls.iproute.DelAddress(cls._MyIPv4Address(netid), 24, ifindex)
298
299  @classmethod
300  def SetDefaultNetwork(cls, netid):
301    table = cls._TableForNetid(netid) if netid else None
302    for version in [4, 6]:
303      is_add = table is not None
304      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)
305
306  @classmethod
307  def ClearDefaultNetwork(cls):
308    cls.SetDefaultNetwork(None)
309
310  @classmethod
311  def GetSysctl(cls, sysctl):
312    return open(sysctl, "r").read()
313
314  @classmethod
315  def SetSysctl(cls, sysctl, value):
316    # Only save each sysctl value the first time we set it. This is so we can
317    # set it to arbitrary values multiple times and still write it back
318    # correctly at the end.
319    if sysctl not in cls.saved_sysctls:
320      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
321    open(sysctl, "w").write(str(value) + "\n")
322
323  @classmethod
324  def _RestoreSysctls(cls):
325    for sysctl, value in cls.saved_sysctls.iteritems():
326      try:
327        open(sysctl, "w").write(value)
328      except IOError:
329        pass
330
331  @classmethod
332  def _ICMPRatelimitFilename(cls, version):
333    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
334                               6: "ipv6/icmp/ratelimit"}[version]
335
336  @classmethod
337  def _SetICMPRatelimit(cls, version, limit):
338    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
339
340  @classmethod
341  def setUpClass(cls):
342    # This is per-class setup instead of per-testcase setup because shelling out
343    # to ip and iptables is slow, and because routing configuration doesn't
344    # change during the test.
345    cls.iproute = iproute.IPRoute()
346    cls.tuns = {}
347    cls.ifindices = {}
348    if HAVE_AUTOCONF_TABLE:
349      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
350      cls.AUTOCONF_TABLE_OFFSET = -1000
351    else:
352      cls.AUTOCONF_TABLE_OFFSET = None
353
354    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
355    for version in [4, 6]:
356      cls._SetICMPRatelimit(version, 0)
357
358    for netid in cls.NETIDS:
359      cls.tuns[netid] = cls.CreateTunInterface(netid)
360      iface = cls.GetInterfaceName(netid)
361      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
362
363      cls.SendRA(netid)
364      cls._RunSetupCommands(netid, True)
365
366    for version in [4, 6]:
367      cls.iproute.UnreachableRule(version, True, 1000)
368
369    # Uncomment to look around at interface and rule configuration while
370    # running in the background. (Once the test finishes running, all the
371    # interfaces and rules are gone.)
372    # time.sleep(30)
373
374  @classmethod
375  def tearDownClass(cls):
376    for version in [4, 6]:
377      try:
378        cls.iproute.UnreachableRule(version, False, 1000)
379      except IOError:
380        pass
381
382    for netid in cls.tuns:
383      cls._RunSetupCommands(netid, False)
384      cls.tuns[netid].close()
385    cls._RestoreSysctls()
386
387  def setUp(self):
388    self.ClearTunQueues()
389
390  def SetSocketMark(self, s, netid):
391    if netid is None:
392      netid = 0
393    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
394
395  def GetSocketMark(self, s):
396    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
397
398  def ClearSocketMark(self, s):
399    self.SetSocketMark(s, 0)
400
401  def BindToDevice(self, s, iface):
402    if not iface:
403      iface = ""
404    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
405
406  def SetUnicastInterface(self, s, ifindex):
407    # Otherwise, Python thinks it's a 1-byte option.
408    ifindex = struct.pack("!I", ifindex)
409
410    # Always set the IPv4 interface, because it will be used even on IPv6
411    # sockets if the destination address is a mapped address.
412    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
413    if s.family == AF_INET6:
414      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
415
416  def GetRemoteAddress(self, version):
417    return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
418
419  def SelectInterface(self, s, netid, mode):
420    if mode == "uid":
421      raise ValueError("Can't change UID on an existing socket")
422    elif mode == "mark":
423      self.SetSocketMark(s, netid)
424    elif mode == "oif":
425      iface = self.GetInterfaceName(netid) if netid else ""
426      self.BindToDevice(s, iface)
427    elif mode == "ucast_oif":
428      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
429    else:
430      raise ValueError("Unknown interface selection mode %s" % mode)
431
432  def BuildSocket(self, version, constructor, netid, routing_mode):
433    uid = self.UidForNetid(netid) if routing_mode == "uid" else None
434    with net_test.RunAsUid(uid):
435      family = self.GetProtocolFamily(version)
436      s = constructor(family)
437
438    if routing_mode not in [None, "uid"]:
439      self.SelectInterface(s, netid, routing_mode)
440
441    return s
442
443  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
444    if netid is not None:
445      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
446      cmsg_level, cmsg_name = {
447          4: (net_test.SOL_IP, IP_PKTINFO),
448          6: (net_test.SOL_IPV6, IPV6_PKTINFO)}[version]
449      cmsgs.append((cmsg_level, cmsg_name, pktinfo))
450    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)
451
452  def ReceiveEtherPacketOn(self, netid, packet):
453    posix.write(self.tuns[netid].fileno(), str(packet))
454
455  def ReceivePacketOn(self, netid, ip_packet):
456    routermac = self.RouterMacAddress(netid)
457    mymac = self.MyMacAddress(netid)
458    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
459    self.ReceiveEtherPacketOn(netid, packet)
460
461  def ReadAllPacketsOn(self, netid, include_multicast=False):
462    packets = []
463    while True:
464      try:
465        packet = posix.read(self.tuns[netid].fileno(), 4096)
466        if not packet:
467          break
468        ether = scapy.Ether(packet)
469        # Multicast frames are frames where the first byte of the destination
470        # MAC address has 1 in the least-significant bit.
471        if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
472          packets.append(ether.payload)
473      except OSError, e:
474        # EAGAIN means there are no more packets waiting.
475        if re.match(e.message, os.strerror(errno.EAGAIN)):
476          break
477        # Anything else is unexpected.
478        else:
479          raise e
480    return packets
481
482  def ClearTunQueues(self):
483    # Keep reading packets on all netids until we get no packets on any of them.
484    waiting = None
485    while waiting != 0:
486      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
487
488  def assertPacketMatches(self, expected, actual):
489    # The expected packet is just a rough sketch of the packet we expect to
490    # receive. For example, it doesn't contain fields we can't predict, such as
491    # initial TCP sequence numbers, or that depend on the host implementation
492    # and settings, such as TCP options. To check whether the packet matches
493    # what we expect, instead of just checking all the known fields one by one,
494    # we blank out fields in the actual packet and then compare the whole
495    # packets to each other as strings. Because we modify the actual packet,
496    # make a copy here.
497    actual = actual.copy()
498
499    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
500    actualip = actual.getlayer("IP")
501    expectedip = expected.getlayer("IP")
502    if actualip and expectedip:
503      actualip.id = expectedip.id
504      actualip.flags &= 5
505      actualip.chksum = None  # Change the header, recalculate the checksum.
506
507    # Blank out UDP fields that we can't predict (e.g., the source port for
508    # kernel-originated packets).
509    actualudp = actual.getlayer("UDP")
510    expectedudp = expected.getlayer("UDP")
511    if actualudp and expectedudp:
512      if expectedudp.sport is None:
513        actualudp.sport = None
514        actualudp.chksum = None
515
516    # Since the TCP code below messes with options, recalculate the length.
517    if actualip:
518      actualip.len = None
519    actualipv6 = actual.getlayer("IPv6")
520    if actualipv6:
521      actualipv6.plen = None
522
523    # Blank out TCP fields that we can't predict.
524    actualtcp = actual.getlayer("TCP")
525    expectedtcp = expected.getlayer("TCP")
526    if actualtcp and expectedtcp:
527      actualtcp.dataofs = expectedtcp.dataofs
528      actualtcp.options = expectedtcp.options
529      actualtcp.window = expectedtcp.window
530      if expectedtcp.sport is None:
531        actualtcp.sport = None
532      if expectedtcp.seq is None:
533        actualtcp.seq = None
534      if expectedtcp.ack is None:
535        actualtcp.ack = None
536      actualtcp.chksum = None
537
538    # Serialize the packet so that expected packet fields that are only set when
539    # a packet is serialized e.g., the checksum) are filled in.
540    expected_real = expected.__class__(str(expected))
541    actual_real = actual.__class__(str(actual))
542    # repr() can be expensive. Call it only if the test is going to fail and we
543    # want to see the error.
544    if expected_real != actual_real:
545      self.assertEquals(repr(expected_real), repr(actual_real))
546
547  def PacketMatches(self, expected, actual):
548    try:
549      self.assertPacketMatches(expected, actual)
550      return True
551    except AssertionError:
552      return False
553
554  def ExpectNoPacketsOn(self, netid, msg):
555    packets = self.ReadAllPacketsOn(netid)
556    if packets:
557      firstpacket = repr(packets[0])
558    else:
559      firstpacket = ""
560    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
561
562  def ExpectPacketOn(self, netid, msg, expected):
563    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
564    # multicast packets unless the packet we expect to see is a multicast
565    # packet. For now the only tests that use this are IPv6.
566    ipv6 = expected.getlayer("IPv6")
567    if ipv6 and ipv6.dst.startswith("ff"):
568      include_multicast = True
569    else:
570      include_multicast = False
571
572    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
573    self.assertTrue(packets, msg + ": received no packets")
574
575    # If we receive a packet that matches what we expected, return it.
576    for packet in packets:
577      if self.PacketMatches(expected, packet):
578        return packet
579
580    # None of the packets matched. Call assertPacketMatches to output a diff
581    # between the expected packet and the last packet we received. In theory,
582    # we'd output a diff to the packet that's the best match for what we
583    # expected, but this is good enough for now.
584    try:
585      self.assertPacketMatches(expected, packets[-1])
586    except Exception, e:
587      raise UnexpectedPacketError(
588          "%s: diff with last packet:\n%s" % (msg, e.message))
589
590  def Combinations(self, version):
591    """Produces a list of combinations to test."""
592    combinations = []
593
594    # Check packets addressed to the IP addresses of all our interfaces...
595    for dest_ip_netid in self.tuns:
596      ip_if = self.GetInterfaceName(dest_ip_netid)
597      myaddr = self.MyAddress(version, dest_ip_netid)
598      remoteaddr = self.GetRemoteAddress(version)
599
600      # ... coming in on all our interfaces.
601      for netid in self.tuns:
602        iif = self.GetInterfaceName(netid)
603        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))
604
605    return combinations
606
607  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
608    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
609    if reply_desc:
610      msg += ": Expecting %s on %s" % (reply_desc, iif)
611    else:
612      msg += ": Expecting no packets on %s" % iif
613    return msg
614
615  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
616    self.ReceivePacketOn(netid, packet)
617    if reply:
618      return self.ExpectPacketOn(netid, msg, reply)
619    else:
620      self.ExpectNoPacketsOn(netid, msg)
621      return None
622