1#!/usr/bin/env python
2# Copyright (c) PLUMgrid, Inc.
3# Licensed under the Apache License, Version 2.0 (the "License")
4
5from ctypes import c_ushort, c_int, c_ulonglong
6from netaddr import IPAddress
7from bcc import BPF
8from pyroute2 import IPRoute
9from socket import socket, AF_INET, SOCK_DGRAM
10import sys
11from time import sleep
12from unittest import main, TestCase
13
# The BPF program source file under test is passed as the first command-line
# argument. Pop it so unittest's own argument parsing (main() below) does not
# see it; fail with a usage message instead of a bare IndexError if missing.
if len(sys.argv) < 2:
    sys.exit("usage: %s <bpf_source_file> [unittest args...]" % sys.argv[0])
arg1 = sys.argv.pop(1)

# Parser-stage identifiers. setUp/test_jumps use them as keys into the BPF
# "jump" (program fd) and "stats" (counter) tables.
S_EOP = 1
S_ETHER = 2
S_ARP = 3
S_IP = 4
20
class TestBPFSocket(TestCase):
    """Exercise BPF tail calls between chained SCHED_CLS programs.

    setUp attaches ``parse_ether`` as a tc bpf filter under an sfq qdisc on
    eth0. It then stores the fds of ``parse_arp``, ``parse_ip`` and ``eop``
    in the ``jump`` table. test_jumps generates traffic and checks the
    per-stage counters in the ``stats`` table.
    """

    # Upper bound (seconds) on how long test_jumps waits for the counters.
    STATS_TIMEOUT = 1.0
    # Delay (seconds) between successive counter polls.
    STATS_POLL_INTERVAL = 0.05

    def setUp(self):
        """Load the programs, attach the tc filter and fill the jump table.

        Every change made to the interface is paired with ``addCleanup``.
        The netlink socket is closed and the qdisc removed even when a later
        step of setUp, or the test itself, fails.
        """
        # Keep the BPF object on the instance for the whole test so the
        # loaded programs and tables it owns are not released early.
        self.bpf = b = BPF(src_file=arg1, debug=0)
        ether_fn = b.load_func("parse_ether", BPF.SCHED_CLS)
        arp_fn = b.load_func("parse_arp", BPF.SCHED_CLS)
        ip_fn = b.load_func("parse_ip", BPF.SCHED_CLS)
        eop_fn = b.load_func("eop", BPF.SCHED_CLS)

        ip = IPRoute()
        self.addCleanup(ip.close)
        ifindex = ip.link_lookup(ifname="eth0")[0]
        ip.tc("add", "sfq", ifindex, "1:")
        # Remove the qdisc (and the filter attached under it) afterwards.
        # Otherwise it outlives the test and a re-run tries to add "1:" again.
        # Cleanups run LIFO, so this happens before ip.close().
        self.addCleanup(ip.tc, "del", "sfq", ifindex, "1:")
        ip.tc("add-filter", "bpf", ifindex, ":1", fd=ether_fn.fd,
              name=ether_fn.name, parent="1:", action="ok", classid=1)

        self.jump = b.get_table("jump", c_int, c_int)
        self.jump[c_int(S_ARP)] = c_int(arp_fn.fd)
        self.jump[c_int(S_IP)] = c_int(ip_fn.fd)
        self.jump[c_int(S_EOP)] = c_int(eop_fn.fd)
        self.stats = b.get_table("stats", c_int, c_ulonglong)

    def _stat(self, key):
        """Return the ``stats`` counter stored under ``key``, or 0 if absent."""
        try:
            return self.stats[c_int(key)].value
        except KeyError:
            # bcc table lookups raise KeyError when the key has no entry.
            # Treat that as "no packet reached this stage" so the test
            # reports an assertion failure rather than an error.
            return 0

    def _stats_ready(self):
        """Return True once every counter meets test_jumps' thresholds."""
        return (self._stat(S_IP) > 0 and self._stat(S_ARP) > 0
                and self._stat(S_EOP) > 1)

    def test_jumps(self):
        """Send one UDP datagram and check each tail-call stage counted it.

        The counters are polled for up to STATS_TIMEOUT seconds. The packets
        sent on our behalf (e.g. neighbour/ARP resolution) may leave the
        interface after ``sendto`` has returned.
        """
        udp = socket(AF_INET, SOCK_DGRAM)
        try:
            udp.sendto(b"a" * 10, ("172.16.1.1", 5000))
        finally:
            udp.close()

        waited = 0.0
        while not self._stats_ready() and waited < self.STATS_TIMEOUT:
            sleep(self.STATS_POLL_INTERVAL)
            waited += self.STATS_POLL_INTERVAL

        self.assertGreater(self._stat(S_IP), 0)
        self.assertGreater(self._stat(S_ARP), 0)
        self.assertGreater(self._stat(S_EOP), 1)
46
if "__main__" == __name__:
    # Hand control to unittest's command-line runner for this module. The
    # BPF source path has already been popped from sys.argv at import time.
    main(module="__main__")
49