1#!/usr/bin/env python
2# Copyright (c) 2017 Facebook, Inc.
3# Licensed under the Apache License, Version 2.0 (the "License")
4
5import ctypes as ct
6import unittest
7from bcc import BPF
8from netaddr import IPAddress
9
class KeyV4(ct.Structure):
    """LPM trie key for IPv4: an unsigned prefix length plus 4 address bytes.

    The address bytes are supplied most-significant octet first, e.g.
    ``KeyV4(24, (192, 168, 0, 0))`` describes 192.168.0.0/24.
    """

    _fields_ = [
        ("prefixlen", ct.c_uint),   # count of leading significant address bits
        ("data", ct.c_ubyte * 4),   # one element per IPv4 octet
    ]
13
class KeyV6(ct.Structure):
    """LPM trie key for IPv6: an unsigned prefix length plus 8 16-bit words.

    NOTE(review): each ``c_ushort`` element is stored in host byte order.
    Feeding ``IPAddress(...).words`` therefore byte-swaps every word on
    little-endian hosts. BPF LPM tries compare key data in network byte
    order, so such keys only behave correctly for prefix lengths that are
    multiples of 16. To get network-order bytes, build ``data`` with
    ``(ct.c_ushort * 8).from_buffer_copy(IPAddress(addr).packed)``.
    """

    _fields_ = [
        ("prefixlen", ct.c_uint),    # count of leading significant address bits
        ("data", ct.c_ushort * 8),   # 128-bit address as eight 16-bit elements
    ]
17
class TestLpmTrie(unittest.TestCase):
    """Longest-prefix-match lookups against a BPF_LPM_TRIE map through bcc."""

    @staticmethod
    def _key_v6(prefixlen, addr):
        """Return a KeyV6 for *addr*/*prefixlen* with data in network byte order.

        ``IPAddress.words`` yields 16-bit integers. Stored into a
        ``c_ushort`` array, they land in host byte order. On little-endian
        hosts that swaps the two bytes of every word, so the trie (which
        matches bits most-significant first, in network order) would
        compare the wrong bits for any prefix that is not a multiple of 16.
        Copying the packed big-endian bytes straight into the array keeps
        the address laid out exactly as the trie expects.
        """
        raw = (ct.c_ushort * 8).from_buffer_copy(IPAddress(addr).packed)
        return KeyV6(prefixlen, raw)

    def test_lpm_trie_v4(self):
        """A /32 lookup returns the longest matching IPv4 prefix; a miss raises KeyError."""
        prog = """
        BPF_LPM_TRIE(trie, u64, int, 16);
        """
        # The u64 key type only has to match KeyV4's 8-byte layout
        # (unsigned prefixlen + 4 address bytes).
        b = BPF(text=prog)
        t = b["trie"]

        # 192.168.0.0/24 -> 24, and the narrower 192.168.0.0/28 -> 28.
        t[KeyV4(24, (192, 168, 0, 0))] = ct.c_int(24)
        t[KeyV4(28, (192, 168, 0, 0))] = ct.c_int(28)

        # .15 falls inside the /28 (hosts .0-.15), so the longer prefix wins.
        self.assertEqual(t[KeyV4(32, (192, 168, 0, 15))].value, 28)
        # .127 is outside the /28 but still inside the /24.
        self.assertEqual(t[KeyV4(32, (192, 168, 0, 127))].value, 24)

        # Build the key outside assertRaises so only the lookup is checked.
        miss = KeyV4(32, (172, 16, 1, 127))
        with self.assertRaises(KeyError):
            t[miss]

    def test_lpm_trie_v6(self):
        """A /128 lookup returns the longest matching IPv6 prefix; a miss raises KeyError."""
        prog = """
        struct key_v6 {
            u32 prefixlen;
            u32 data[4];
        };
        BPF_LPM_TRIE(trie, struct key_v6, int, 16);
        """
        b = BPF(text=prog)
        t = b["trie"]

        t[self._key_v6(64, '2a00:1450:4001:814:200e::')] = ct.c_int(64)
        t[self._key_v6(96, '2a00:1450:4001:814::200e')] = ct.c_int(96)

        # Shares the first 96 bits (2a00:1450:4001:814:0:0) with the /96 entry.
        k = self._key_v6(128, '2a00:1450:4001:814::1024')
        self.assertEqual(t[k].value, 96)

        # Fifth word 0x2046 breaks the /96 match; only the /64 still matches.
        k = self._key_v6(128, '2a00:1450:4001:814:2046::')
        self.assertEqual(t[k].value, 64)

        # Build the key outside assertRaises so only the lookup is checked.
        miss = self._key_v6(128, '2a00:ffff::')
        with self.assertRaises(KeyError):
            t[miss]
72
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers the TestCase classes defined in this module.
    unittest.main()
75