1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import itertools
18import random
19import unittest
20
21from socket import *
22
23import iproute
24import multinetwork_base
25import net_test
26import packets
27
28
class ForwardingTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks that IPv6 forwarding doesn't crash the system.

  The test drives an accepted TCP connection into TIME_WAIT, deletes the local
  address that connection used, and then injects packets belonging to that
  connection on a network whose traffic is being forwarded (IPv6 forwarding
  enabled plus iif rules between two other networks).

  Relevant kernel commits:
    upstream net-next:
      e7eadb4 ipv6: inet6_sk() should use sk_fullsock()
    android-3.10:
      feee3c1 ipv6: inet6_sk() should use sk_fullsock()
      cdab04e net: add sk_fullsock() helper
    android-3.18:
      8246f18 ipv6: inet6_sk() should use sk_fullsock()
      bea19db net: add sk_fullsock() helper
  """

  # Linux TCP state number for TIME_WAIT (include/net/tcp_states.h). It is
  # compared against the hex "st" column read from /proc/net/tcp6 below.
  TCP_TIME_WAIT = 6

  def ForwardBetweenInterfaces(self, enabled, iface1, iface2):
    """Installs or removes IPv6 iif rules forwarding between two networks.

    For each ordered pair (iif, oif) of the two networks, calls
    iproute.IifRule at PRIORITY_IIF so that packets arriving on iif's
    interface are looked up in oif's routing table.

    Args:
      enabled: True to install the rules, False to remove them (testCrash
        calls it with True and then with False in its cleanup).
      iface1: netid of the first network (a key of self.tuns).
      iface2: netid of the second network (a key of self.tuns).
    """
    for iif, oif in itertools.permutations([iface1, iface2]):
      self.iproute.IifRule(6, enabled, self.GetInterfaceName(iif),
                           self._TableForNetid(oif), self.PRIORITY_IIF)

  def setUp(self):
    # Chain to the base class so its per-test setup still runs.
    super(ForwardingTest, self).setUp()
    self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)

  def tearDown(self):
    self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
    super(ForwardingTest, self).tearDown()

  def CheckForwardingCrash(self, netid, iface1, iface2):
    """Puts a socket in TIME_WAIT and injects its packets on a forwarded net.

    Args:
      netid: network the TCP connection is established on.
      iface1: network on which the stale connection's packets are injected
        after the local address has been removed.
      iface2: the other forwarded network (unused here; kept so callers can
        pass the same pair given to ForwardBetweenInterfaces).
    """
    version = 6
    listensocket = net_test.IPv6TCPSocket()
    # The outer try guarantees the listening socket is closed even if any
    # expectation or assertion below fails, not only after DelAddress.
    try:
      self.SetSocketMark(listensocket, netid)
      listenport = net_test.BindRandomPort(version, listensocket)

      remoteaddr = self.GetRemoteAddress(version)
      myaddr = self.MyAddress(version, netid)

      # Complete a three-way handshake from the simulated remote peer.
      desc, syn = packets.SYN(listenport, version, remoteaddr, myaddr)
      synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
      msg = "Sent %s, expected %s" % (desc, synack_desc)
      reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)

      establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
      self.ReceivePacketOn(netid, establishing_ack)
      accepted, _ = listensocket.accept()
      try:
        remoteport = accepted.getpeername()[1]
      finally:
        # Closing our side first is what makes our socket end up in
        # TIME_WAIT once the peer's FIN is received.
        accepted.close()

      desc, fin = packets.FIN(version, myaddr, remoteaddr, establishing_ack)
      self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc,
                          fin)

      desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
      self.ReceivePacketOn(netid, finack)

      # Check our socket is now in TIME_WAIT.
      sockets = self.ReadProcNetSocket("tcp6")
      mysrc = "%s:%04X" % (net_test.FormatSockStatAddress(myaddr), listenport)
      mydst = "%s:%04X" % (net_test.FormatSockStatAddress(remoteaddr),
                           remoteport)
      sockets = [s for s in sockets if s[0] == mysrc and s[1] == mydst]
      self.assertEqual(1, len(sockets))
      self.assertEqual("%02X" % self.TCP_TIME_WAIT, sockets[0][2])

      # Remove our IP address, then feed packets for the TIME_WAIT connection
      # in on iface1 so that they are forwarded rather than delivered locally.
      try:
        self.iproute.DelAddress(myaddr, 64, self.ifindices[netid])

        self.ReceivePacketOn(iface1, finack)
        self.ReceivePacketOn(iface1, establishing_ack)
        self.ReceivePacketOn(iface1, establishing_ack)
        # No crashes? Good.

      finally:
        # Put back our IP address.
        self.SendRA(netid)
    finally:
      listensocket.close()

  def testCrash(self):
    """Runs CheckForwardingCrash over randomly chosen network triples."""
    # Run the test a few times as it doesn't crash/hang the first time.
    for netids in itertools.permutations(self.tuns):
      # Pick an interface to send traffic on and two to forward traffic between.
      netid, iface1, iface2 = random.sample(netids, 3)
      self.ForwardBetweenInterfaces(True, iface1, iface2)
      try:
        self.CheckForwardingCrash(netid, iface1, iface2)
      finally:
        self.ForwardBetweenInterfaces(False, iface1, iface2)
115
116
if __name__ == "__main__":
  # Discover and run every test case defined in this module when executed
  # directly as a script.
  unittest.main(module="__main__")
119