1#!/usr/bin/python
2#
3# Copyright 2019 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import unittest
18
19from errno import *  # pylint: disable=wildcard-import
20from socket import *  # pylint: disable=wildcard-import
21import ctypes
22import fcntl
23import os
24import random
25import select
26import termios
27import threading
28import time
29from scapy import all as scapy
30
31import multinetwork_base
32import net_test
33import packets
34
# Re-export the socket constants this test needs from net_test so they can be
# used unqualified below.
SOL_TCP = net_test.SOL_TCP
SHUT_RD, SHUT_WR, SHUT_RDWR = (
    net_test.SHUT_RD, net_test.SHUT_WR, net_test.SHUT_RDWR)

# The socket ioctls SIOCINQ/SIOCOUTQ share their request numbers with the
# terminal ioctls FIONREAD/TIOCOUTQ, which Python exposes through termios.
SIOCINQ, SIOCOUTQ = termios.FIONREAD, termios.TIOCOUTQ

# Remote port that every simulated connection in this file connects to.
TEST_PORT = 5555
43
# SOL_TCP-level socket options and the argument values they accept, mirroring
# the definitions in the Linux kernel header include/uapi/linux/tcp.h.

# Option names, used as the optname of setsockopt()/getsockopt() at SOL_TCP.
TCP_REPAIR, TCP_REPAIR_QUEUE, TCP_QUEUE_SEQ = 19, 20, 21

# Arguments to TCP_REPAIR: leave or enter repair mode.
TCP_REPAIR_OFF, TCP_REPAIR_ON = 0, 1

# Arguments to TCP_REPAIR_QUEUE: which queue TCP_QUEUE_SEQ then refers to.
TCP_NO_QUEUE, TCP_RECV_QUEUE, TCP_SEND_QUEUE = 0, 1, 2
60
# This test ensures that TCP keepalive offload works correctly when it fetches
# TCP information from the kernel via TCP repair mode.
class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests TCP repair mode and related socket state on simulated connections.

  Packets the kernel sends are captured with ExpectPacketOn() and packets from
  the remote peer are injected with ReceivePacketOn(). self.last_sent and
  self.last_received hold the most recent segment built for each direction so
  that follow-up packets can be constructed from them.
  """

  def assertSocketNotConnected(self, sock):
    """Asserts that sock has no peer: getpeername() fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock has a peer; getpeername() raises if it does not."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def createConnectedSocket(self, version, netid):
    """Returns a TCP socket on netid connected to the remote test address.

    The three-way handshake is simulated: the SYN emitted by connect() is
    captured, a SYN+ACK with a random initial sequence number and window 14400
    is injected, and the resulting ACK is captured. That ACK and SYN+ACK
    become self.last_sent and self.last_received respectively.
    """
    s = net_test.TCPSocket(net_test.GetAddressFamily(version))
    net_test.DisableFinWait(s)
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    remoteaddr = self.GetRemoteAddress(version)
    # connect() is expected to return EINPROGRESS; the handshake is completed
    # below by injecting the peer's SYN+ACK.
    self.assertRaisesErrno(EINPROGRESS, s.connect, (remotesockaddr, TEST_PORT))
    self.assertSocketNotConnected(s)

    myaddr = self.MyAddress(version, netid)
    port = s.getsockname()[1]
    self.assertNotEqual(0, port)

    desc, expect_syn = packets.SYN(TEST_PORT, version, myaddr, remoteaddr, port,
                                   seq=None)
    msg = "socket connect: expected %s" % desc
    syn = self.ExpectPacketOn(netid, msg, expect_syn)
    _, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
    synack.getlayer("TCP").seq = random.getrandbits(32)
    synack.getlayer("TCP").window = 14400
    self.ReceivePacketOn(netid, synack)
    desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
    msg = "socket connect: got SYN+ACK, expected %s" % desc
    ack = self.ExpectPacketOn(netid, msg, ack)
    self.last_sent = ack
    self.last_received = synack
    return s

  def receiveFin(self, netid, version, sock):
    """Injects a FIN from the remote peer that answers self.last_sent."""
    self.assertSocketConnected(sock)
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, fin = packets.FIN(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, fin)
    self.last_received = fin

  def sendData(self, netid, version, sock, payload):
    """Writes payload to sock and records the matching data segment.

    The segment is only constructed and stored as self.last_sent; this helper
    does not wait for it to appear on the wire.
    """
    sock.send(payload)

    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, send = packets.ACK(version, myaddr, remoteaddr,
                          self.last_received, payload)
    self.last_sent = send

  def receiveData(self, netid, version, payload):
    """Injects payload from the remote peer and expects our ACK for it."""
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)

    _, received = packets.ACK(version, remoteaddr, myaddr,
                              self.last_sent, payload)
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, received)
    self.ReceivePacketOn(netid, received)
    # Give the kernel a moment to generate the ACK before looking for it.
    time.sleep(0.1)
    self.ExpectPacketOn(netid, "expecting %s" % ack_desc, ack)
    self.last_sent = ack
    self.last_received = received

  # Test the behavior of NO_QUEUE. Incoming data is still queued, but the
  # socket can be neither read nor written in repair mode with NO_QUEUE.
  def testTcpRepairInNoQueue(self):
    for version in [4, 5, 6]:
      self.tcpRepairInNoQueueTest(version)

  def tcpRepairInNoQueueTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    # In repair mode with NO_QUEUE, writes fail...
    self.assertRaisesErrno(EINVAL, sock.send, "write test")

    # ...data from the remote peer is still accepted and acknowledged...
    TEST_RECEIVED = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECEIVED)

    # ...but it cannot be read until repair mode is turned off.
    self.assertRaisesErrno(EPERM, sock.recv, 4096)

    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    readData = sock.recv(4096)
    self.assertEqual(readData, TEST_RECEIVED)
    sock.close()

  # Test whether tcp read/write sequence numbers can be fetched correctly
  # through TCP_QUEUE_SEQ.
  def testGetSequenceNumber(self):
    for version in [4, 5, 6]:
      self.GetSequenceNumberTest(version)

  @staticmethod
  def _Seq32(value):
    """Returns value reduced to an unsigned 32-bit TCP sequence number.

    getsockopt() returns integer options as a (signed) C int, and TCP
    sequence numbers wrap modulo 2**32, so values must be normalized before
    they are compared or advanced.
    """
    return value & 0xffffffff

  def GetSequenceNumberTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    # Write queue: starts at the sequence number of our handshake ACK and
    # advances by the number of bytes written.
    sequence_before = self.GetWriteSequenceNumber(version, sock)
    expect_sequence = self.last_sent.getlayer("TCP").seq
    self.assertEqual(self._Seq32(sequence_before), expect_sequence)
    TEST_SEND = net_test.UDP_PAYLOAD
    self.sendData(netid, version, sock, TEST_SEND)
    sequence_after = self.GetWriteSequenceNumber(version, sock)
    # Compare modulo 2**32 so a wrap (or a sign flip of the C int) between the
    # two reads does not cause a spurious failure.
    self.assertEqual(self._Seq32(sequence_before + len(TEST_SEND)),
                     self._Seq32(sequence_after))

    # Read queue: starts one past the peer's SYN+ACK sequence number and
    # advances by the number of bytes received.
    sequence_before = self.GetReadSequenceNumber(version, sock)
    expect_sequence = self.last_received.getlayer("TCP").seq + 1
    self.assertEqual(self._Seq32(sequence_before), self._Seq32(expect_sequence))
    TEST_READ = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_READ)
    sequence_after = self.GetReadSequenceNumber(version, sock)
    self.assertEqual(self._Seq32(sequence_before + len(TEST_READ)),
                     self._Seq32(sequence_after))
    sock.close()

  def _GetQueueSequenceNumber(self, sock, queue):
    """Returns the raw TCP_QUEUE_SEQ value for queue on sock.

    Temporarily enters repair mode to select queue (TCP_SEND_QUEUE or
    TCP_RECV_QUEUE), then reselects TCP_NO_QUEUE and leaves repair mode.
    """
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, queue)
    sequence = sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    return sequence

  def GetWriteSequenceNumber(self, version, sock):
    """Returns the raw send-queue TCP_QUEUE_SEQ of sock (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_SEND_QUEUE)

  def GetReadSequenceNumber(self, version, sock):
    """Returns the raw receive-queue TCP_QUEUE_SEQ of sock (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_RECV_QUEUE)

  # Test whether a tcp repair socket can be poll()'ed correctly
  # in multiple threads at the same time.
  def testMultiThreadedPoll(self):
    for version in [4, 5, 6]:
      self.PollWhenShutdownTest(version)
      self.PollWhenReceiveFinTest(version)

  def PollRepairSocketInMultipleThreads(self, netid, version, expected):
    """Connects a socket, enters repair mode, and starts two poller threads.

    Each thread runs fdSelect(sock, expected). Returns (sock, threads).
    """
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    multiThreads = []
    for _ in range(2):
      thread = SocketExceptionThread(sock,
                                     lambda sk: self.fdSelect(sk, expected))
      thread.start()
      self.assertTrue(thread.is_alive())
      multiThreads.append(thread)

    return sock, multiThreads

  def assertThreadsStopped(self, multiThreads, msg):
    """Asserts that every poller thread finished and that its checks passed.

    Waits up to one second for each thread that is still running. A thread
    that is still alive afterwards fails the test with msg. So does any
    exception the thread recorded (e.g. a failed poll-event assertion in
    fdSelect), which would otherwise be silently discarded.
    """
    for thread in multiThreads:
      if thread.is_alive():
        thread.join(1)
      if thread.is_alive():
        thread.stop()
        raise AssertionError(msg)
      if thread.exception is not None:
        raise AssertionError("%s: %r" % (msg, thread.exception))

  def PollWhenShutdownTest(self, version):
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown RD.
    sock.shutdown(SHUT_RD)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RD")
    sock.close()

    # After SHUT_WR, any event reported to the pollers fails fdSelect's check.
    expected = None
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown WR.
    sock.shutdown(SHUT_WR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_WR")
    sock.close()

    expected = select.POLLIN | select.POLLHUP
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown RDWR.
    sock.shutdown(SHUT_RDWR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RDWR")
    sock.close()

  def PollWhenReceiveFinTest(self, version):
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    self.receiveFin(netid, version, sock)
    self.assertThreadsStopped(multiThreads, "poll fail during FIN")
    sock.close()

  # Test whether socket idle can be detected by SIOCINQ and SIOCOUTQ.
  def testSocketIdle(self):
    for version in [4, 5, 6]:
      self.readQueueIdleTest(version)
      self.writeQueueIdleTest(version)

  def readQueueIdleTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)

    # SIOCINQ reports the number of unread bytes in the receive queue.
    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, 0)

    TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECV_PAYLOAD)
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, len(TEST_RECV_PAYLOAD))
    sock.close()

  def writeQueueIdleTest(self, version):
    netid = self.RandomNetid()
    # Set up a connected socket; the write queue is empty.
    sock = self.createConnectedSocket(version, netid)
    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    # Switch to repair mode with SEND_QUEUE and write some data to the queue.
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    sock.close()

    # Set up a connected socket again.
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    # Send out some data and don't acknowledge it yet.
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    # Inject the peer's ACK; the data should leave the write queue.
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, ack)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    sock.close()

  def fdSelect(self, sock, expected):
    """Polls sock once and checks the events reported for it.

    Registers sock for POLLIN | POLLPRI | POLLHUP | POLLERR | POLLNVAL, waits
    up to 500 ms, and asserts that every event reported for sock equals
    expected. If nothing is reported before the timeout, nothing is asserted.
    An event for any other fd raises AssertionError.
    """
    READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR | select.POLLNVAL
    p = select.poll()
    p.register(sock, READ_ONLY)
    events = p.poll(500)
    for fd, event in events:
      if fd == sock.fileno():
        self.assertEqual(event, expected)
      else:
        raise AssertionError("unexpected poll fd")
321
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records what it raised.

  An IOError or AssertionError raised by the operation is stored in
  self.exception (which stays None otherwise) instead of escaping the thread,
  so the thread that started it can inspect the outcome after join(). Other
  exception types propagate out of run() as usual.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    # Daemonized so a thread still blocked in operation() cannot keep the
    # interpreter alive at exit.
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def stop(self):
    """Best-effort marks the thread as stopped.

    Python 2's threading.Thread has the private _Thread__stop(), which marks
    the thread finished for is_alive()/join() without interrupting the code
    it is running. Python 3 has no such method, so there this is a no-op and
    the (daemon) thread is simply abandoned instead of raising AttributeError.
    """
    py2_stop = getattr(self, "_Thread__stop", None)
    if py2_stop is not None:
      py2_stop()

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as e:
      self.exception = e
339
if __name__ == '__main__':
  # Running this file directly executes every test case it defines.
  unittest.main()
342