1#!/usr/bin/python
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Base module for multinetwork tests."""
18
19import errno
20import fcntl
21import os
22import posix
23import random
24import re
25from socket import *  # pylint: disable=wildcard-import
26import struct
27import time
28
29from scapy import all as scapy
30
31import csocket
32import iproute
33import net_test
34
35
# Flags and ioctl request number used to create tun/tap interfaces.
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454CA

# SOL_SOCKET option that binds a socket to a network device.
SO_BINDTODEVICE = 25

# Option numbers passed to setsockopt().
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Ancillary-data (cmsg) type numbers.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_HOPLIMIT = 52  # Unlike IPV6_UNICAST_HOPS, this one is cmsg only.

# Sysctl files read or written by the tests.
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"

# True when the kernel exposes the sysctl that puts RA-learned routes into
# per-interface routing tables.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
60
61
class UnexpectedPacketError(AssertionError):
  """Raised when none of the received packets matches the expected packet."""


class ConfigurationError(AssertionError):
  """Raised when a command that configures the test environment fails."""
64
65
def MakePktInfo(version, addr, ifindex):
  """Builds a packed IP_PKTINFO / IPV6_PKTINFO structure for use in a cmsg.

  Args:
    version: The IP version, 4 or 6.
    addr: The address to store, as a string. A false value (None, "") selects
      the unspecified address ("0.0.0.0" or "::").
    ifindex: The interface index to store.

  Returns:
    The packed structure, as returned by the csocket Pack() method.

  Raises:
    KeyError: if version is not 4 or 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # addr is guaranteed to be non-empty here, so convert it unconditionally.
  packed_addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((packed_addr, ifindex)).Pack()
  # The trailing 4-byte field (presumably ipi_spec_dst) is left zeroed.
  return csocket.InPktinfo((ifindex, packed_addr, "\x00" * 4)).Pack()
76
77
class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tap interfaces created
  through the tun driver. The environment is designed to be similar to a real
  Android device in terms of rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since each netid is formatted as a single byte
  # into MAC addresses, IPv4 addresses and IPv6 interface IDs. A false netid
  # (0 or None) is used by several methods below to mean "no network".
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes. This is a
  # class attribute that is mutated in place, so it is shared by subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 300  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range routed to netid."""
    start = cls.UID_RANGE_SIZE * netid
    return (start, start + cls.UID_RANGE_SIZE - 1)

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID in netid's range, or 0 if netid is false."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    When per-interface autoconf tables are in use (AUTOCONF_TABLE_OFFSET is
    set in setUpClass), the table is the interface index shifted by
    -AUTOCONF_TABLE_OFFSET; otherwise the netid itself is the table number.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    return netid

  @staticmethod
  def GetInterfaceName(netid):
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid: link-local for IPv6."""
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid.

    Version 5 denotes IPv4 used through an IPv6 socket (see
    MySocketAddress), so it returns the IPv4 address. Only the address for
    the requested version is looked up.

    Raises:
      KeyError: if version is not 4, 5 or 6.
    """
    getters = {4: cls._MyIPv4Address,
               5: cls._MyIPv4Address,
               6: cls._MyIPv6Address}
    return getters[version](netid)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but IPv4-mapped ("::ffff:...") for version 5."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix.

    prefix is expected to end with a separator, e.g. "172.22." (IPv4, any
    prefix containing ".") or "2001:db8:aaaa::" (IPv6).
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates the tap interface for netid and returns its open tun file."""
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b")
    except IOError:
      f = open("/dev/tun", "r+b")
    ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
    # Pad the request to 40 bytes (presumably sizeof(struct ifreq)).
    ifr += "\x00" * (40 - len(ifr))
    fcntl.ioctl(f, TUNSETIFF, ifr)
    # Give ourselves a predictable MAC address.
    net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
    # Disable DAD so we don't have to wait for it.
    cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
    # Set accept_ra to 2, because that's what we use.
    cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
    net_test.SetInterfaceUp(iface)
    net_test.SetNonBlocking(f)
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0, options=()):
    """Injects a Router Advertisement from netid's router into its interface.

    The RA advertises netid's on-link IPv6 prefix with the L and A flags set
    and carries the router's MAC address as a source link-layer option.

    Args:
      netid: The netid whose router sends the RA.
      retranstimer: The retrans timer in ms. Defaults to RA_VALIDITY in ms.
      reachabletime: The value of the RA's reachable time field.
      options: Extra scapy ND options appended to the RA.
    """
    validity = cls.RA_VALIDITY  # seconds
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = validity * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlifetime) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), str(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds (is_add=True) or deletes the rules, addresses and routes of netid."""
    # None of these depend on the IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or deletes iptables rules that mark packets arriving on iface.

    Packets received on iface get their mark set to netid, for IPv4 and IPv6.

    Raises:
      ConfigurationError: if an iptables command reports failure.
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      # Run iptables (or ip6tables) to set up incoming packet marking.
      if net_test.RunIptablesCommand(version, args):
        # ConfigurationError is expected to be defined at module level.
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Points the default rule at netid's table, or removes it if netid is false."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    with open(sysctl, "r") as f:
      return f.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to the sysctl file, remembering its original value."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as f:
      f.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every sysctl value saved by SetSysctl (best effort)."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as f:
          f.write(value)
      except IOError:
        # Deliberately ignored: some saved sysctls (e.g. per-interface ones)
        # may no longer exist by the time this runs.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    del os.environ["ANDROID_NO_USE_FWMARK_CLIENT"]

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on s to netid (0 if netid is None)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Sets SO_BINDTODEVICE; a false iface is passed as an empty string."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)

  def SetUnicastInterface(self, s, ifindex):
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Routes socket s via netid using one of the selection modes.

    Args:
      s: The socket.
      netid: The netid to select.
      mode: One of "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: if mode is not recognized.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket with constructor and routes it via netid.

    Version 5 creates an IPv6 socket. routing_mode None leaves the socket
    unconfigured; any other value is passed to SelectInterface.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))

    if routing_mode is not None:
      self.SelectInterface(s, netid, routing_mode)

    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids.

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s, forcing the egress interface via a pktinfo cmsg.

    If netid is not None, a pktinfo cmsg selecting netid's interface is added
    to a copy of cmsgs; the caller's list is not modified.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      cmsgs = list(cmsgs) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    posix.write(self.tuns[netid].fileno(), str(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Injects ip_packet into netid's interface as if sent by its router."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Reads Ethernet frames from netid's tun file until it reports EAGAIN or
    returns an empty read. If the queue runs dry before any packet has been
    collected, waits 10 ms and retries once, in case a packet is in flight.

    Args:
      netid: The netid from which to read packets.
      include_multicast: A boolean, whether to also return packets sent to
        multicast MAC addresses (default=False, i.e., drop them).

    Returns:
      A list of the Ethernet payloads of the frames that were kept.

    Raises:
      OSError: if reading fails with any error other than EAGAIN.
    """
    fd = self.tuns[netid].fileno()
    packets = []
    retries = 0
    max_retries = 1
    while True:
      try:
        packet = posix.read(fd, 4096)
      except OSError as e:
        # The tun file is non-blocking, so EAGAIN means there are no more
        # packets waiting; anything else is unexpected. Compare errno rather
        # than the error message, which is not a reliable thing to match on.
        if e.errno != errno.EAGAIN:
          raise
        # If we didn't see any packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(str(expected))
    actual_real = actual.__class__(str(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Like assertPacketMatches, but returns a bool instead of raising."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Returns the first packet on netid matching expected.

    Raises:
      AssertionError: if no packets were received.
      UnexpectedPacketError: if packets were received but none matched.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    include_multicast = bool(ipv6 and ipv6.dst.startswith("ff"))

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except AssertionError as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, e))

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples, one for every
      (receiving interface, destination interface address) pair.
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid, then expects reply (or nothing if reply is false)."""
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None
750
751
class InboundMarkingTest(MultiNetworkBaseTest):
  """Multi-network test base that also marks inbound packets with their netid.

  The marking rules are installed after the base environment has been set up
  and are removed before the base environment is torn down.
  """

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    # Install the per-interface inbound marking rules.
    cls.SetInboundMarks(is_add=True)

  @classmethod
  def tearDownClass(cls):
    # Remove the marking rules first, then let the base class clean up.
    cls.SetInboundMarks(is_add=False)
    super(InboundMarkingTest, cls).tearDownClass()
764