1#!/usr/bin/python
2#
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
import ctypes
import ctypes.util
import errno
import os
import socket
import unittest

from bpf import *  # pylint: disable=wildcard-import
import csocket
import net_test
25
# Handle to the C library. use_errno=True makes ctypes keep a private copy of
# errno for calls made through this handle (readable via ctypes.get_errno()).
# find_library lives in the ctypes.util submodule, which must be imported
# explicitly; "import ctypes" alone does not load it.
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)

# The tests in this file are only run on kernels reporting version 4.4.0 or
# later; on older kernels the whole test class is skipped.
HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0)
28
@unittest.skipUnless(HAVE_EBPF_SUPPORT,
                     "eBPF function not fully supported")
class BpfTest(net_test.NetworkTest):
  """Tests eBPF maps and socket-filter programs through the bpf helpers.

  Every map fd, program fd and socket created by a test is registered with
  addCleanup, so it is released even when an assertion fails part-way.
  """

  # Number of keys inserted by testIterateMap (keys 1..99 inclusive).
  _ITERATE_KEY_COUNT = 99

  def _CreateHashMap(self):
    """Returns the fd of a new BPF_MAP_TYPE_HASH map, closed at cleanup.

    NOTE(review): the arguments 4, 4, 100 are presumably key size, value size
    and maximum number of entries; confirm against bpf.CreateMap.
    """
    map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
    self.addCleanup(os.close, map_fd)
    return map_fd

  def _LoadSocketFilter(self, bpf_prog):
    """Loads bpf_prog as a BPF_PROG_TYPE_SOCKET_FILTER program.

    Args:
      bpf_prog: The instruction string built by concatenating Bpf* results.

    Returns:
      The value returned by BpfProgLoad (the program fd), closed at cleanup.
    """
    insn_buff = ctypes.create_string_buffer(bpf_prog)
    prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
                          ctypes.addressof(insn_buff),
                          len(insn_buff), BpfInsn._length)
    self.addCleanup(os.close, prog_fd)
    return prog_fd

  def _OpenFilteredLoopbackSocket(self, prog_fd):
    """Creates a UDP socket on 127.0.0.1 with prog_fd attached to it.

    The socket has a one second timeout and is bound to a port chosen by the
    kernel. It is closed automatically at test cleanup.

    Args:
      prog_fd: Program fd to pass to BpfProgAttach.

    Returns:
      A (socket, address) tuple, where address is what getsockname() reports
      after binding.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    self.addCleanup(sock.close)
    sock.settimeout(1)
    BpfProgAttach(sock.fileno(), prog_fd)
    sock.bind(("127.0.0.1", 0))
    return sock, sock.getsockname()

  def testCreateMap(self):
    """A stored value can be read back and is gone after DeleteMap."""
    key, value = 1, 1
    map_fd = self._CreateHashMap()
    UpdateMap(map_fd, key, value)
    self.assertEqual(LookupMap(map_fd, key).value, value)
    DeleteMap(map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, map_fd, key)

  def testIterateMap(self):
    """GetNextKey visits every inserted key, then fails with ENOENT."""
    map_fd = self._CreateHashMap()
    value = 1024
    inserted_keys = xrange(1, self._ITERATE_KEY_COUNT + 1)
    for key in inserted_keys:
      UpdateMap(map_fd, key, value)
    for key in inserted_keys:
      self.assertEqual(LookupMap(map_fd, key).value, value)
    # 101 was never inserted.
    self.assertRaisesErrno(errno.ENOENT, LookupMap, map_fd, 101)

    # Start from key 0, which is not in the map. Per bpf(2), asking for the
    # key after a missing key yields the first key of the map.
    key = 0
    for _ in inserted_keys:
      key = GetNextKey(map_fd, key).value
      self.assertGreater(key, 0)
      self.assertEqual(LookupMap(map_fd, key).value, value)
    # All keys have been visited, so there is no next key.
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, map_fd, key)

  def testProgLoad(self):
    """A filter that accepts every packet must not disturb loopback traffic.

    The program does nothing except pass every packet it receives. If it
    blocked transmission, the receive below would fail.
    """
    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
    bpf_prog += BpfExitInsn()
    prog_fd = self._LoadSocketFilter(bpf_prog)
    sock, addr = self._OpenFilteredLoopbackSocket(prog_fd)
    sockaddr = csocket.Sockaddr(addr)
    sock.sendto("foo", addr)
    data, retaddr = csocket.Recvfrom(sock, 4096, 0)
    self.assertEqual("foo", data)
    self.assertEqual(sockaddr, retaddr)

  def testPacketBlock(self):
    """A filter that returns 0 drops the packet, so the receive times out."""
    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    bpf_prog += BpfMov64Imm(BPF_REG_0, 0)
    bpf_prog += BpfExitInsn()
    prog_fd = self._LoadSocketFilter(bpf_prog)
    sock, addr = self._OpenFilteredLoopbackSocket(prog_fd)
    sock.sendto("foo", addr)
    self.assertRaisesErrno(errno.EAGAIN, csocket.Recvfrom, sock, 4096, 0)

  def testPacketCount(self):
    """Counts packets seen by a socket filter in a map and checks the total.

    The program keeps a per-socket packet counter in an eBPF map under a
    fixed key. After sending and receiving packet_count datagrams, the map
    entry must equal packet_count.

    NOTE(review): the per-instruction comments below assume the Bpf* builders
    take their register operands as (dst, src) and that BpfLoadMapFd emits
    two instructions (a 64-bit immediate load); confirm against the bpf
    module.
    """
    map_fd = self._CreateHashMap()
    key = 0xf0f0
    # Keep the context pointer (r1 on entry) in r6 across the helper calls.
    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    # Lookup: r1 = map, r2 = address of the key stored on the stack at fp-4.
    bpf_prog += BpfLoadMapFd(map_fd, BPF_REG_1)
    bpf_prog += BpfMov64Imm(BPF_REG_7, key)
    bpf_prog += BpfStxMem(BPF_W, BPF_REG_10, BPF_REG_7, -4)
    bpf_prog += BpfMov64Reg(BPF_REG_8, BPF_REG_10)
    bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_8, -4)
    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
    bpf_prog += BpfFuncLookupMap()
    # Skip the next 10 instructions (the "create entry" path) when the lookup
    # returned a non-NULL pointer.
    # NOTE(review): BPF_AND is an ALU opcode. In the kernel uapi headers it
    # has the same value (0x50) as the jump opcode BPF_JNE, which is why this
    # behaves as "if r0 != 0". Prefer BPF_JNE if the bpf module exports it.
    bpf_prog += BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10)
    # No entry yet: insert key -> 1 (value stored at fp-8, flags in r4 = 0),
    # then return the word at offset 0 of the context and exit.
    bpf_prog += BpfLoadMapFd(map_fd, BPF_REG_1)
    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
    bpf_prog += BpfStMem(BPF_W, BPF_REG_10, -8, 1)
    bpf_prog += BpfMov64Reg(BPF_REG_3, BPF_REG_10)
    bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_3, -8)
    bpf_prog += BpfMov64Imm(BPF_REG_4, 0)
    bpf_prog += BpfFuncUpdateMap()
    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
    bpf_prog += BpfExitInsn()
    # Entry exists: atomically add 1 to the value the lookup pointed at, then
    # return the word at offset 0 of the context and exit.
    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_0)
    bpf_prog += BpfMov64Imm(BPF_REG_1, 1)
    bpf_prog += BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1,
                           0, 0)
    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
    bpf_prog += BpfExitInsn()
    prog_fd = self._LoadSocketFilter(bpf_prog)
    sock, addr = self._OpenFilteredLoopbackSocket(prog_fd)
    sockaddr = csocket.Sockaddr(addr)
    packet_count = 100
    for _ in xrange(packet_count):
      sock.sendto("foo", addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      self.assertEqual("foo", data)
      self.assertEqual(sockaddr, retaddr)
    self.assertEqual(LookupMap(map_fd, key).value, packet_count)
149
150
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
153