1 /*
2  * Copyright (c) 2017 Politecnico di Torino
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "BPF.h"
18 
19 #include "catch.hpp"
20 
21 #include <random>
22 #include <iostream>
23 
24 #include <linux/version.h>
25 
26 TEST_CASE("test array table", "[array_table]") {
27   const std::string BPF_PROGRAM = R"(
28     BPF_TABLE("hash", int, int, myhash, 128);
29     BPF_TABLE("array", int, int, myarray, 128);
30   )";
31 
32   // turn off the rw_engine
33   ebpf::BPF bpf(0, nullptr, false);
34   ebpf::StatusTuple res(0);
35   res = bpf.init(BPF_PROGRAM);
36   REQUIRE(res.code() == 0);
37 
38   ebpf::BPFArrayTable<int> t = bpf.get_array_table<int>("myarray");
39 
40   SECTION("bad table type") {
41     // try to get table of wrong type
__anonda43bfd80102()42     auto f1 = [&](){
43       bpf.get_array_table<int>("myhash");
44     };
45 
46     REQUIRE_THROWS(f1());
47   }
48 
49   SECTION("standard methods") {
50     int i, v1, v2;
51     i = 1;
52     v1 = 42;
53     // update element
54     res = t.update_value(i, v1);
55     REQUIRE(res.code() == 0);
56     res = t.get_value(i, v2);
57     REQUIRE(res.code() == 0);
58     REQUIRE(v2 == 42);
59 
60     // update another element
61     i = 2;
62     v1 = 69;
63     res = t.update_value(i, v1);
64     REQUIRE(res.code() == 0);
65     res = t.get_value(i, v2);
66     REQUIRE(res.code() == 0);
67     REQUIRE(v2 == 69);
68 
69     // get non existing element
70     i = 1024;
71     res = t.get_value(i, v2);
72     REQUIRE(res.code() != 0);
73   }
74 
75   SECTION("full table") {
76     // random number generator
77     std::mt19937 rng;
78     rng.seed(std::random_device()());
79     std::uniform_int_distribution<int> dist;
80 
81     std::vector<int> localtable(128);
82 
83     for(int i = 0; i < 128; i++) {
84       int v = dist(rng);
85 
86       res = t.update_value(i, v);
87       REQUIRE(res.code() == 0);
88 
89       // save it in the local table to compare later on
90       localtable[i] = v;
91     }
92 
93     std::vector<int> offlinetable = t.get_table_offline();
94     REQUIRE(localtable == offlinetable);
95   }
96 }
97 
98 #if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 6, 0)
99 TEST_CASE("percpu array table", "[percpu_array_table]") {
100   const std::string BPF_PROGRAM = R"(
101     BPF_TABLE("percpu_hash", int, u64, myhash, 128);
102     BPF_TABLE("percpu_array", int, u64, myarray, 64);
103   )";
104 
105   ebpf::BPF bpf;
106   ebpf::StatusTuple res(0);
107   res = bpf.init(BPF_PROGRAM);
108   REQUIRE(res.code() == 0);
109 
110   ebpf::BPFPercpuArrayTable<uint64_t> t = bpf.get_percpu_array_table<uint64_t>("myarray");
111   size_t ncpus = ebpf::BPFTable::get_possible_cpu_count();
112 
113   SECTION("bad table type") {
114     // try to get table of wrong type
__anonda43bfd80202()115     auto f1 = [&](){
116       bpf.get_percpu_array_table<uint64_t>("myhash");
117     };
118 
119     REQUIRE_THROWS(f1());
120   }
121 
122   SECTION("standard methods") {
123     int i;
124     std::vector<uint64_t> v1(ncpus);
125     std::vector<uint64_t> v2;
126 
127     for (size_t j = 0; j < ncpus; j++) {
128       v1[j] = 42 * j;
129     }
130 
131     i = 1;
132     // update element
133     res = t.update_value(i, v1);
134     REQUIRE(res.code() == 0);
135     res = t.get_value(i, v2);
136     REQUIRE(res.code() == 0);
137     REQUIRE(v2.size() == ncpus);
138     for (size_t j = 0; j < ncpus; j++) {
139       REQUIRE(v2.at(j) == 42 * j);
140     }
141 
142     // get non existing element
143     i = 1024;
144     res = t.get_value(i, v2);
145     REQUIRE(res.code() != 0);
146   }
147 }
148 #endif
149