#!/usr/bin/env python
# Copyright (c) PLUMgrid, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")

# Test program for BPF socket filters and BPF table access.
#
# It loads the "on_packet" socket-filter function from the program named on
# the command line, attaches it to a raw socket on eth0, sends 100 ICMP echo
# requests with `ping -f`, and checks the rx/tx counters found in the
# program's "stats" table.  It then exercises the table's dict-like API
# (lookup, assignment, del, clear, popitem, len).
#
# Usage: <this script> <bpf-program> [bpf-arg2] [unittest args...]
# (Loading BPF programs, raw sockets and flood ping typically need root.)

from ctypes import c_uint, c_ulong, Structure
from netaddr import IPAddress
from bcc import BPF
from subprocess import check_call
import sys
from unittest import main, TestCase

if len(sys.argv) < 2:
    # Fail with a usage line instead of an IndexError from pop(1) below.
    sys.exit("usage: %s <bpf-program> [bpf-arg2] [unittest args...]"
             % sys.argv[0])

# Consume this script's own arguments so unittest.main() only sees the rest.
arg1 = sys.argv.pop(1)
# NOTE: whatever follows the program name is taken as the second BPF()
# argument, so unittest options must come after it.
arg2 = ""
if len(sys.argv) > 1:
    arg2 = sys.argv.pop(1)

# For ".b" programs the key/leaf layouts of the "stats" table are declared
# here; otherwise None is passed and get_table() is left to determine them
# (presumably from the program itself).
Key = None
Leaf = None
if arg1.endswith(".b"):
    class Key(Structure):
        _fields_ = [("dip", c_uint),
                    ("sip", c_uint)]

    # NOTE(review): c_ulong's width is platform-dependent; it must match the
    # counter width declared in the .b program -- confirm.
    class Leaf(Structure):
        _fields_ = [("rx_pkts", c_ulong),
                    ("tx_pkts", c_ulong)]


class TestBPFSocket(TestCase):
    """Check the packet counters and the dict-like API of the "stats" table."""

    def setUp(self):
        """Load and attach the socket filter, then open the "stats" table."""
        b = BPF(arg1, arg2, debug=0)
        fn = b.load_func("on_packet", BPF.SOCKET_FILTER)
        BPF.attach_raw_socket(fn, "eth0")
        self.stats = b.get_table("stats", Key, Leaf)

    def test_ping(self):
        """Flood-ping 172.16.1.1, verify the counters, then the table API."""
        check_call(["ping", "-f", "-c", "100", "172.16.1.1"])

        # Key built from the two endpoint addresses in the table's Key field
        # order; the filter is expected to have counted 100 packets in each
        # direction under this single entry.
        key = self.stats.Key(IPAddress("172.16.1.2").value,
                             IPAddress("172.16.1.1").value)
        leaf = self.stats[key]
        self.assertEqual(leaf.rx_pkts, 100)
        self.assertEqual(leaf.tx_pkts, 100)

        # After deletion, both a lookup and a second delete must fail.
        del self.stats[key]
        with self.assertRaises(KeyError):
            self.stats[key]
        with self.assertRaises(KeyError):
            del self.stats[key]

        # clear() empties the table; re-inserting restores exactly one entry.
        self.stats.clear()
        self.assertEqual(len(self.stats), 0)
        self.stats[key] = leaf
        self.assertEqual(len(self.stats), 1)
        self.stats.clear()
        self.assertEqual(len(self.stats), 0)

    def test_empty_key(self):
        """An all-zero key must behave like any other key."""
        self.stats.clear()
        self.stats[self.stats.Key()] = self.stats.Leaf(100, 200)
        # The zero-key entry is the only one, so popitem() removes it.
        self.stats.popitem()
        self.stats[self.stats.Key(10, 20)] = self.stats.Leaf(300, 400)
        with self.assertRaises(KeyError):
            self.stats[self.stats.Key()]

        # The only remaining entry is the (10, 20) one.
        _, popped = self.stats.popitem()
        self.assertEqual(popped.rx_pkts, 300)
        self.assertEqual(popped.tx_pkts, 400)
        self.stats.clear()
        self.assertEqual(len(self.stats), 0)

        # The zero key and keys that differ only in the second field must be
        # stored as distinct entries.
        self.stats[self.stats.Key()] = popped
        for second in range(1, 4):
            self.stats[self.stats.Key(0, second)] = popped
        self.assertEqual(len(self.stats), 4)


if __name__ == "__main__":
    main()