#!/usr/bin/python
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base module for multinetwork tests."""

import errno
import fcntl
import os
import posix
import random
import re
from socket import *  # pylint: disable=wildcard-import
import struct

from scapy import all as scapy

import csocket
import cstruct
import iproute
import net_test


IFF_TUN = 1
IFF_TAP = 2
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

SO_BINDTODEVICE = 25

# Setsockopt values.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values.
IP_TTL = 2
IP_PKTINFO = 8
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_PKTINFO = 50
IPV6_HOPLIMIT = 52  # Different from IPV6_UNICAST_HOPS, this is cmsg only.

# Data structures.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
InPktinfo = cstruct.Struct("in_pktinfo", "@i4s4s", "ifindex spec_dst addr")
In6Pktinfo = cstruct.Struct("in6_pktinfo", "@16si", "addr ifindex")


def HaveUidRouting():
  """Checks whether the kernel supports UID routing.

  Installs a temporary IPv6 UID range rule, dumps the IPv6 rules to see whether
  the kernel kept the UID range attribute, and deletes the temporary rule.

  Returns:
    False if the probe rule could not be created; otherwise True if any dumped
    rule carries an FRA_UID_START attribute, and False if none does.
  """
  # Create a rule with the UID range selector. If the kernel doesn't understand
  # the selector, it will create a rule with no selectors.
  try:
    iproute.IPRoute().UidRangeRule(6, True, 1000, 2000, 100, 10000)
  except IOError:
    return False

  try:
    # Dump all the rules. If we find a rule using the UID range selector, then
    # the kernel supports UID range routing.
    rules = iproute.IPRoute().DumpRules(6)
    return any("FRA_UID_START" in attrs for _, attrs in rules)
  finally:
    # Delete the probe rule even if dumping the rules failed, so that we never
    # leave a stray rule behind.
    iproute.IPRoute().UidRangeRule(6, False, 1000, 2000, 100, 10000)

AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"

HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
HAVE_UID_ROUTING = HaveUidRouting()


class UnexpectedPacketError(AssertionError):
  pass


def MakePktInfo(version, addr, ifindex):
  """Builds a packed in_pktinfo / in6_pktinfo structure for use in a cmsg.

  Args:
    version: IP version, 4 or 6.
    addr: textual IP address, or a false value to use the unspecified address
      ("0.0.0.0" or "::").
    ifindex: interface index to store in the structure.

  Returns:
    The packed structure. For IPv4, the address is stored in the spec_dst
    field and the addr field is zeroed.

  Raises:
    KeyError: if version is neither 4 nor 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  addr = inet_pton(family, addr)
  if version == 6:
    return In6Pktinfo((addr, ifindex)).Pack()
  else:
    return InPktinfo((ifindex, addr, "\x00" * 4)).Pack()


class MultiNetworkBaseTest(net_test.NetworkTest):

  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses, IIDs and
  # IPv4 addresses (formatted with %02x and as a single IPv4 octet).
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes.
  # NOTE: this dict is a class attribute shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range assigned to netid."""
    return (
        cls.UID_RANGE_SIZE * netid,
        cls.UID_RANGE_SIZE * (netid + 1) - 1
    )

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID from netid's UID range."""
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    If AUTOCONF_TABLE_OFFSET is set and netid has a known interface index, the
    table is ifindex - AUTOCONF_TABLE_OFFSET; otherwise it is netid itself.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    else:
      return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the tap interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address assigned to our own interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's IP address on netid for the given IP version.

    Raises:
      ValueError: if version is neither 4 nor 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns our statically-configured IPv4 address on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns our IPv6 address on netid, as reported by net_test."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our IP address of the given version on netid.

    Only the lookup for the requested version is performed.

    Raises:
      KeyError: if version is neither 4 nor 6.
    """
    getter = {4: cls._MyIPv4Address, 6: cls._MyIPv6Address}[version]
    return getter(netid)

  @staticmethod
  def IPv6Prefix(netid):
    """Returns the /64 IPv6 prefix advertised on netid (without length)."""
    return "2001:db8:%02x::" % netid

  @staticmethod
  def GetRandomDestination(prefix):
    """Returns a random address under prefix.

    If prefix contains a "." it is treated as IPv4 and two decimal octets are
    appended; otherwise two random hex groups are appended.
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 31), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Returns AF_INET or AF_INET6 for IP version 4 or 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates and brings up the tap interface for netid.

    Returns:
      The open /dev/net/tun file object, switched to non-blocking mode.
    """
    iface = cls.GetInterfaceName(netid)
    f = open("/dev/net/tun", "r+b")
    try:
      ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
      # Pad the request out to 40 bytes (presumably sizeof(struct ifreq)).
      ifr += "\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except BaseException:
      # Don't leak the tun file descriptor if configuration fails.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None):
    """Injects a router advertisement from netid's router into its tap.

    The RA carries a source link-layer address option and a /64 prefix
    information option for IPv6Prefix(netid) with L and A flags set.

    Args:
      netid: the netid whose tap interface receives the RA.
      retranstimer: value for the RA's retrans timer field; defaults to the
        same value as the prefix lifetimes.
    """
    validity = 300  # seconds
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = validity

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(retranstimer=retranstimer,
                            routerlifetime=routerlifetime) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.IPv6Prefix(netid),
                                      prefixlen=64,
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    posix.write(cls.tuns[netid].fileno(), str(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds or deletes the rules, addresses, neighbours and routes for netid.

    Args:
      netid: the netid to configure.
      is_add: True to add the configuration, False to delete it.
    """
    # Find out how to configure things. None of these depend on the IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      if HAVE_UID_ROUTING:
        start, end = cls.UidRangeForNetid(netid)
        cls.iproute.UidRangeRule(version, is_add, start, end, table,
                                 cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid), 24, ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
          if version == 6:
            cls.iproute.AddRoute(version, table,
                                 cls.IPv6Prefix(netid), 64, None, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          if version == 6:
            cls.iproute.DelRoute(version, table,
                                 cls.IPv6Prefix(netid), 64, None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid), 24, ifindex)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Adds a default rule pointing at netid's table, or deletes it.

    If netid is a false value (e.g., None), the default rule is deleted.
    """
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Deletes the default network rule."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents of the sysctl file at path sysctl."""
    with open(sysctl, "r") as f:
      return f.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to the sysctl file, remembering its original contents.

    Raises:
      IOError: if the sysctl cannot be read or written.
    """
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as f:
      f.write(str(value) + "\n")

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back all saved sysctl values, ignoring any that fail."""
    for sysctl, value in cls.saved_sysctls.iteritems():
      try:
        # Closing inside the try means errors reported at flush/close time are
        # also ignored, rather than escaping from a file destructor.
        with open(sysctl, "w") as f:
          f.write(value)
      except IOError:
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the path of the ICMP rate limit sysctl for the IP version."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit sysctl for the IP version to limit."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    """Creates the tap interfaces, rules and routes for all NETIDS."""
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.) This also requires importing time.
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    """Removes the configuration created by setUpClass.

    Sysctls are restored even if removing the configuration fails, and each
    tun is closed even if deconfiguring it fails.
    """
    try:
      for version in [4, 6]:
        try:
          cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
        except IOError:
          pass

      for netid in cls.tuns:
        try:
          cls._RunSetupCommands(netid, False)
        finally:
          cls.tuns[netid].close()
    finally:
      cls._RestoreSysctls()

  def setUp(self):
    """Drains any packets left over on the tun interfaces."""
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on socket s to netid (None means 0)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK value of socket s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Resets the SO_MARK value of socket s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Binds socket s to interface iface; a false iface unbinds it."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) to ifindex."""
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the remote test address for IP version 4 or 6."""
    return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]

  def SelectInterface(self, s, netid, mode):
    """Routes an existing socket s to netid using the given selection mode.

    Args:
      s: the socket to configure.
      netid: the netid to select; a false value clears the selection where the
        mode supports it.
      mode: one of "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: if mode is "uid" (not possible on an existing socket) or is
        not a known mode.
    """
    if mode == "uid":
      raise ValueError("Can't change UID on an existing socket")
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket with constructor and routes it to netid.

    For routing_mode "uid", the socket is created inside
    net_test.RunAsUid() with a UID from netid's range; other non-None modes
    are applied afterwards via SelectInterface.
    """
    uid = self.UidForNetid(netid) if routing_mode == "uid" else None
    with net_test.RunAsUid(uid):
      family = self.GetProtocolFamily(version)
      s = constructor(family)

    if routing_mode not in [None, "uid"]:
      self.SelectInterface(s, netid, routing_mode)

    return s

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s, adding a pktinfo cmsg selecting netid if given.

    The caller's cmsgs sequence is not modified.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, IP_PKTINFO),
          6: (net_test.SOL_IPV6, IPV6_PKTINFO)}[version]
      # Build a new list rather than appending to the caller's.
      cmsgs = list(cmsgs or []) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Writes the Ethernet frame packet into netid's tap interface."""
    posix.write(self.tuns[netid].fileno(), str(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Wraps ip_packet in an Ethernet header from the router and injects it."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Reads all frames currently queued on netid's tap interface.

    Args:
      netid: the netid whose tap interface to read from.
      include_multicast: if False, frames with a multicast destination MAC
        address are dropped.

    Returns:
      A list of the payloads above the Ethernet header of the frames read.

    Raises:
      OSError: if reading fails with any error other than EAGAIN.
    """
    packets = []
    fd = self.tuns[netid].fileno()
    while True:
      try:
        packet = posix.read(fd, 4096)
      except OSError as e:
        # EAGAIN means there are no more packets waiting.
        if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
          break
        # Anything else is unexpected.
        raise
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def ClearTunQueues(self):
    """Drains all tun interfaces until a full pass reads no packets."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches the rough sketch in expected.

    Raises:
      AssertionError: if the normalized packets differ.
    """
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    actualipv6 = actual.getlayer("IPv6")
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(str(expected))
    actual_real = actual.__class__(str(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns True if actual matches expected, False otherwise."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Asserts that no (non-multicast) packets are queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Asserts that a packet matching expected is queued on netid.

    Returns:
      The first received packet that matches expected.

    Raises:
      AssertionError: if no packets were received.
      UnexpectedPacketError: if packets were received but none matched.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    if ipv6 and ipv6.dst.startswith("ff"):
      include_multicast = True
    else:
      include_multicast = False

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      # str(e) keeps the text even for exceptions built with several args.
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e)))

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples: one for each
      pair of (interface whose IP is the destination, incoming interface).
    """
    combinations = []
    remoteaddr = self.GetRemoteAddress(version)

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds a human-readable description of a receive/expect step."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and checks for reply (or for no packets).

    Returns:
      The matching reply packet, or None if no reply was expected.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None