1#!/usr/bin/python
2# @lint-avoid-python-3-compatibility-imports
3#
4# tcpsubnet   Summarize TCP bytes sent to different subnets.
5#             For Linux, uses BCC, eBPF. Embedded C.
6#
7# USAGE: tcpsubnet [-h] [-v] [-J] [-f FORMAT] [-i INTERVAL] [subnets]
8#
9# This uses dynamic tracing of kernel functions, and will need to be updated
10# to match kernel changes.
11#
# This is an adaptation of tcptop, written by Brendan Gregg.
13#
# WARNING: This traces all sends at the TCP level, and while it
15# summarizes data in-kernel to reduce overhead, there may still be some
16# overhead at high TCP send/receive rates (eg, ~13% of one CPU at 100k TCP
17# events/sec. This is not the same as packet rate: funccount can be used to
18# count the kprobes below to find out the TCP rate). Test in a lab environment
19# first. If your send rate is low (eg, <1k/sec) then the overhead is
20# expected to be negligible.
21#
22# Copyright 2017 Rodrigo Manyari
23# Licensed under the Apache License, Version 2.0 (the "License")
24#
25# 03-Oct-2017   Rodrigo Manyari   Created this based on tcptop.
26# 13-Feb-2018   Rodrigo Manyari   Fix pep8 errors, some refactoring.
27# 05-Mar-2018   Rodrigo Manyari   Add date time to output.
28
29import argparse
30import json
31import logging
32import struct
33import socket
34from bcc import BPF
35from datetime import datetime as dt
36from time import sleep
37
38# arguments
# Epilog shown at the bottom of --help
examples = """examples:
    ./tcpsubnet                 # Trace TCP sent to the default subnets:
                                # 127.0.0.1/32,10.0.0.0/8,172.16.0.0/12,
                                # 192.168.0.0/16,0.0.0.0/0
    ./tcpsubnet -f K            # Trace TCP sent to the default subnets
                                # aggregated in KBytes.
    ./tcpsubnet 10.80.0.0/24    # Trace TCP sent to 10.80.0.0/24 only
    ./tcpsubnet -J              # Format the output in JSON.
"""

# Subnets used when none are given on the command line
default_subnets = ",".join([
    "127.0.0.1/32",
    "10.0.0.0/8",
    "172.16.0.0/12",
    "192.168.0.0/16",
    "0.0.0.0/0",
])

parser = argparse.ArgumentParser(
    description="Summarize TCP send and aggregate by subnet",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples,
)
parser.add_argument(
    "subnets",
    nargs="?",
    type=str,
    default=default_subnets,
    help="comma separated list of subnets",
)
parser.add_argument(
    "-v", "--verbose",
    action="store_true",
    help="output debug statements",
)
parser.add_argument(
    "-J", "--json",
    action="store_true",
    help="format output in JSON",
)
# Hidden from --help
parser.add_argument(
    "--ebpf",
    action="store_true",
    help=argparse.SUPPRESS,
)
parser.add_argument(
    "-f", "--format",
    choices=["b", "k", "m", "B", "K", "M"],
    default="B",
    help="[bkmBKM] format to report: bits, Kbits, Mbits, bytes, "
         "KBytes, MBytes (default B)",
)
parser.add_argument(
    "-i", "--interval",
    type=int,
    default=1,
    help="output interval, in seconds (default 1)",
)
args = parser.parse_args()

# -v/--verbose raises the log level from INFO to DEBUG
level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(level=level)

logging.debug("Starting with the following args:")
logging.debug(args)

# args checking: the interval is stored as an int and must be positive
args.interval = int(args.interval)
if args.interval <= 0:
    logging.error("Invalid interval, must be > 0. Exiting.")
    exit(1)
86
87# map of supported formats
# Converters from a byte count to the unit selected with -f/--format:
# lower-case letters report bits, upper-case letters report bytes.
# Floor division keeps every result an integer on both Python 2 and
# Python 3 (where "/" would produce floats), so the JSON output carries the
# same whole numbers that the "%6d" text output prints.
formats = {
    "b": lambda x: x * 8,
    "k": lambda x: (x * 8) // 1024,
    "m": lambda x: (x * 8) // pow(1024, 2),
    "B": lambda x: x,
    "K": lambda x: x // 1024,
    "M": lambda x: x // pow(1024, 2)
}
96
# Look up the conversion function for the requested format once here so
# the lookup does not have to be repeated on every interval
formatFn = formats[args.format]
100
# Skeleton of the BPF program (embedded C). The kprobe__tcp_sendmsg handler
# only looks at AF_INET sockets and reads the destination address into dst.
# The __SUBNETS__ placeholder is replaced later with the per-subnet branches
# built by generate_bpf_subnets(), which add `size` to the ipv4_send_bytes
# map under the index of the matching subnet.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>

struct index_key_t {
  u32 index;
};

BPF_HASH(ipv4_send_bytes, struct index_key_t);

int kprobe__tcp_sendmsg(struct pt_regs *ctx, struct sock *sk,
    struct msghdr *msg, size_t size)
{
    u16 family = sk->__sk_common.skc_family;

    if (family == AF_INET) {
        u32 dst = sk->__sk_common.skc_daddr;
        unsigned categorized = 0;
        __SUBNETS__
    }
    return 0;
}
"""
126
127
128# Takes in a mask and returns the integer equivalent
129# e.g.
130# mask_to_int(8) returns 4278190080
def mask_to_int(n):
    """Return the 32-bit netmask with the n most significant bits set.

    For example mask_to_int(8) is 4278190080 (255.0.0.0) and
    mask_to_int(0) is 0.
    """
    ones = (1 << n) - 1
    host_bits = 32 - n
    return ones << host_bits
133
# Takes in a list of subnet strings and returns a list
# of 3-element lists containing:
# - The subnet string, as given, at index 0
# - The addr portion as an int at index 1
# - The mask portion as an int at index 2
#
# e.g.
# parse_subnets(['10.10.0.0/24']) returns
# [
#   ['10.10.0.0/24', 168427520, 4294967040],
# ]
def parse_subnets(subnets):
    """Validate "addr/len" subnet strings and convert them to integers.

    Args:
        subnets: iterable of strings such as "10.10.0.0/24".

    Returns:
        A list with one [subnet, addr, mask] entry per input, in input
        order: the string as given, the address as an unsigned 32-bit int
        and the netmask as an int computed by mask_to_int().

    Raises:
        ValueError: if an entry does not have exactly one "/", if the
            address is rejected by socket.inet_aton(), or if the mask is
            not an int between 0 and 32.
    """
    m = []
    for s in subnets:
        parts = s.split("/")
        if len(parts) != 2:
            msg = "Subnet [%s] is invalid, please refer to the examples." % s
            raise ValueError(msg)
        # Catch only what the conversions themselves raise; a bare except
        # would also turn KeyboardInterrupt/SystemExit into a ValueError.
        try:
            netaddr_int = struct.unpack("!I", socket.inet_aton(parts[0]))[0]
        except (socket.error, ValueError, TypeError):
            msg = ("Invalid net address in subnet [%s], " +
                "please refer to the examples.") % s
            raise ValueError(msg)
        try:
            prefix_len = int(parts[1])
        except (ValueError, TypeError):
            msg = "Invalid mask in subnet [%s]. Mask must be an int" % s
            raise ValueError(msg)
        if prefix_len < 0 or prefix_len > 32:
            msg = ("Invalid mask in subnet [%s]. Must be an " +
                "int between 0 and 32.") % s
            raise ValueError(msg)
        m.append([s, netaddr_int, mask_to_int(prefix_len)])
    return m
172
def generate_bpf_subnets(subnets):
    """Build the C code that replaces __SUBNETS__ in the BPF program.

    subnets is a list of [subnet, addr, mask] entries as returned by
    parse_subnets(). One `if` branch is emitted per entry, in list order,
    and the entry's position in the list becomes the map key index. Every
    branch checks `categorized` first, so a send is only counted against
    the first subnet it matches.

    Returns the concatenated branches as a string ('' for an empty list).
    """
    template = """
        if (!categorized && (__NET_ADDR__ & __NET_MASK__) ==
             (dst & __NET_MASK__)) {
          struct index_key_t key = {.index = __POS__};
          ipv4_send_bytes.increment(key, size);
          categorized = 1;
        }
    """

    def render(pos, entry):
        # addr and mask go through socket.htonl() before being embedded as
        # decimal literals that the C code compares against dst
        net_addr = str(socket.htonl(entry[1]))
        net_mask = str(socket.htonl(entry[2]))
        return (template
                .replace("__NET_ADDR__", net_addr)
                .replace("__NET_MASK__", net_mask)
                .replace("__POS__", str(pos)))

    return ''.join(render(pos, entry) for pos, entry in enumerate(subnets))
190
# Split the comma separated argument; an empty value means no subnets
subnets = args.subnets.split(",") if args.subnets else []
subnets = parse_subnets(subnets)

logging.debug("Packets are going to be categorized in the following subnets:")
logging.debug(subnets)

# Splice the generated per-subnet branches into the BPF program text
bpf_subnets = generate_bpf_subnets(subnets)
bpf_text = bpf_text.replace("__SUBNETS__", bpf_subnets)

logging.debug("Done preprocessing the BPF program, "
              "this is what will actually get executed:")
logging.debug(bpf_text)

# --ebpf: print the generated program and stop without loading it
if args.ebpf:
    print(bpf_text)
    exit()

# initialize BPF
b = BPF(text=bpf_text)
ipv4_send_bytes = b["ipv4_send_bytes"]

if not args.json:
    print("Tracing... Output every %d secs. Hit Ctrl-C to end" % args.interval)
219
220# output
# Report loop: sleep for one interval, print what was sent to each subnet,
# reset the counters and repeat. Ctrl-C during the sleep triggers one last
# report before exiting.
exiting = 0
while (1):

    try:
        sleep(args.interval)
    except KeyboardInterrupt:
        exiting = 1

    # Read the map once and use the values returned by items() directly,
    # instead of looking each key up in the map again while printing; the
    # printed numbers then always agree with the sort order.
    entries = ipv4_send_bytes.items()

    # to hold json data
    data = {}

    # output
    now = dt.now()
    data['date'] = now.strftime('%x')
    data['time'] = now.strftime('%X')
    data['entries'] = {}
    if not args.json:
        print(now.strftime('[%x %X]'))
    # largest byte count first
    for k, v in reversed(sorted(entries, key=lambda kv: kv[1].value)):
        # k.index is the position of the subnet in the parsed subnet list
        subnet = subnets[k.index][0]
        send = formatFn(int(v.value))
        if args.json:
            data['entries'][subnet] = send
        else:
            print("%-21s %6d" % (subnet, send))

    if args.json:
        print(json.dumps(data))

    # start the next interval from zero
    ipv4_send_bytes.clear()

    if exiting:
        exit(0)
263