1#!/usr/bin/env python
2#
# xdp_drop_count.py Drop incoming packets at the XDP layer and count the
#                   drops per IP protocol number
5#
6# Copyright (c) 2016 PLUMgrid
7# Copyright (c) 2016 Jan Ruth
8# Licensed under the Apache License, Version 2.0 (the "License")
9
10from bcc import BPF
11import pyroute2
12import time
13import sys
14
# Bitmask handed to attach_xdp()/remove_xdp() further down; it stays 0 unless
# the -S option ORs in the SKB-mode bit.
flags = 0


def usage():
    """Print the command-line synopsis and terminate with exit status 1."""
    print("Usage: {0} [-S] <ifdev>".format(sys.argv[0]))
    print("       -S: use skb mode\n")
    print("e.g.: {0} eth0\n".format(sys.argv[0]))
    # sys.exit() rather than the bare exit(): the latter is injected by the
    # `site` module for interactive sessions and is not guaranteed to exist
    # (e.g. under `python -S`).
    sys.exit(1)
21
# Parse the command line.  The accepted forms are "<ifdev>", "-S <ifdev>" and
# "<ifdev> -S"; anything else prints the usage text and exits.
args = sys.argv[1:]
if not 1 <= len(args) <= 2:
    usage()

if len(args) == 2:
    # Two arguments are only meaningful as the -S switch plus an interface
    # name, so a pair without -S is rejected instead of dropping one word.
    if "-S" not in args:
        usage()
    # XDP_FLAGS_SKB_MODE
    flags |= 1 << 1
    args.remove("-S")

# The single remaining argument must be the interface, never the switch
# itself (this catches "-S" given alone as well as "-S -S").
if args[0] == "-S":
    usage()

device = args[0]
37
# Hook the program is loaded for.  Swap in the commented-out assignment to
# run it as a tc classifier instead of at the XDP layer.
mode = BPF.XDP
#mode = BPF.SCHED_CLS

# Select the "drop" verdict and the context struct name that belong to the
# chosen hook; both are passed to the C source as -D macros when it is built.
ret, ctxtype = (
    ("XDP_DROP", "xdp_md") if mode == BPF.XDP
    else ("TC_ACT_SHOT", "__sk_buff")
)
47
# C source of the packet program.  RETURNCODE and CTXTYPE are not defined in
# the text itself; they arrive as -D macros through the compiler flags below.
bpf_text = """
#define KBUILD_MODNAME "foo"
#include <uapi/linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/if_packet.h>
#include <linux/if_vlan.h>
#include <linux/ip.h>
#include <linux/ipv6.h>


BPF_TABLE("percpu_array", uint32_t, long, dropcnt, 256);

static inline int parse_ipv4(void *data, u64 nh_off, void *data_end) {
    struct iphdr *iph = data + nh_off;

    if ((void*)&iph[1] > data_end)
        return 0;
    return iph->protocol;
}

static inline int parse_ipv6(void *data, u64 nh_off, void *data_end) {
    struct ipv6hdr *ip6h = data + nh_off;

    if ((void*)&ip6h[1] > data_end)
        return 0;
    return ip6h->nexthdr;
}

int xdp_prog1(struct CTXTYPE *ctx) {

    void* data_end = (void*)(long)ctx->data_end;
    void* data = (void*)(long)ctx->data;

    struct ethhdr *eth = data;

    // drop packets
    int rc = RETURNCODE; // let pass XDP_PASS or redirect to tx via XDP_TX
    long *value;
    uint16_t h_proto;
    uint64_t nh_off = 0;
    uint32_t index;

    nh_off = sizeof(*eth);

    if (data + nh_off  > data_end)
        return rc;

    h_proto = eth->h_proto;

    if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
        struct vlan_hdr *vhdr;

        vhdr = data + nh_off;
        nh_off += sizeof(struct vlan_hdr);
        if (data + nh_off > data_end)
            return rc;
            h_proto = vhdr->h_vlan_encapsulated_proto;
    }
    if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
        struct vlan_hdr *vhdr;

        vhdr = data + nh_off;
        nh_off += sizeof(struct vlan_hdr);
        if (data + nh_off > data_end)
            return rc;
            h_proto = vhdr->h_vlan_encapsulated_proto;
    }

    if (h_proto == htons(ETH_P_IP))
        index = parse_ipv4(data, nh_off, data_end);
    else if (h_proto == htons(ETH_P_IPV6))
       index = parse_ipv6(data, nh_off, data_end);
    else
        index = 0;

    value = dropcnt.lookup(&index);
    if (value)
        *value += 1;

    return rc;
}
"""

# Compiler flags: -w plus the two macros that inject the verdict constant and
# the context struct name selected for the current mode.
bpf_cflags = ["-w", "-DRETURNCODE=%s" % ret, "-DCTXTYPE=%s" % ctxtype]

# load BPF program
b = BPF(text=bpf_text, cflags=bpf_cflags)
132
# Load the xdp_prog1 entry point as a program of the selected type.
fn = b.load_func("xdp_prog1", mode)

if mode != BPF.XDP:
    # tc path: resolve the interface index, add a clsact qdisc to it and
    # install the program there as a direct-action bpf filter.
    ip = pyroute2.IPRoute()
    ipdb = pyroute2.IPDB(nl=ip)
    idx = ipdb.interfaces[device].index
    ip.tc("add", "clsact", idx)
    filter_opts = dict(fd=fn.fd, name=fn.name, parent="ffff:fff2",
                       classid=1, direct_action=True)
    ip.tc("add-filter", "bpf", idx, ":1", **filter_opts)
else:
    # XDP path: attach directly to the device with the requested flags.
    b.attach_xdp(device, fn, flags)
144
# Poll the drop counters once per second until interrupted, then detach.
dropcnt = b.get_table("dropcnt")
# Cumulative count seen on the previous pass for each of the 256 table slots,
# used to turn the running totals into per-interval deltas.
prev = [0] * 256
print("Printing drops per IP protocol-number, hit CTRL+C to stop")
try:
    while True:
        for k in dropcnt.keys():
            # The table is declared as a percpu_array, so read the summed
            # value for this key rather than a single slot.
            val = dropcnt.sum(k).value
            i = k.value
            if val:
                delta = val - prev[i]
                prev[i] = val
                print("{}: {} pkt/s".format(i, delta))
        time.sleep(1)
except KeyboardInterrupt:
    print("Removing filter from device")
finally:
    # Always detach, even when the loop dies with an unexpected exception:
    # the program drops every packet, so leaving it attached would keep the
    # interface blackholed after this script has gone.
    if mode == BPF.XDP:
        b.remove_xdp(device, flags)
    else:
        ip.tc("del", "clsact", idx)
        ipdb.release()
167