1#!/usr/bin/python
2# @lint-avoid-python-3-compatibility-imports
3#
4# tcpstates   Trace the TCP session state changes with durations.
5#             For Linux, uses BCC, BPF. Embedded C.
6#
# USAGE: tcpstates [-h] [-T] [-t] [-w] [-s] [-L LOCALPORT] [-D REMOTEPORT]
8#
9# This uses the sock:inet_sock_set_state tracepoint, added to Linux 4.16.
10# Linux 4.16 also adds more state transitions so that they can be traced.
11#
12# Copyright 2018 Netflix, Inc.
13# Licensed under the Apache License, Version 2.0 (the "License")
14#
15# 20-Mar-2018   Brendan Gregg   Created this.
16
17from __future__ import print_function
18from bcc import BPF
19import argparse
20from socket import inet_ntop, AF_INET, AF_INET6
21from struct import pack
22import ctypes as ct
23from time import strftime
24
# arguments
examples = """examples:
    ./tcpstates           # trace all TCP state changes
    ./tcpstates -t        # include timestamp column
    ./tcpstates -T        # include time column (HH:MM:SS)
    ./tcpstates -w        # wider columns (fit IPv6)
    ./tcpstates -stT      # csv output, with times & timestamps
    ./tcpstates -L 80     # only trace local port 80
    ./tcpstates -L 80,81  # only trace local ports 80 and 81
    ./tcpstates -D 80     # only trace remote port 80
"""
parser = argparse.ArgumentParser(
    description="Trace TCP session state changes and durations",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
parser.add_argument("-T", "--time", action="store_true",
    help="include time column on output (HH:MM:SS)")
parser.add_argument("-t", "--timestamp", action="store_true",
    help="include timestamp on output (seconds)")
parser.add_argument("-w", "--wide", action="store_true",
    help="wide column output (fits IPv6 addresses)")
parser.add_argument("-s", "--csv", action="store_true",
    help="comma separated values output")
# -L / -D are kept as raw strings here; they are split on ',' and turned
# into BPF filter code later in the script.
parser.add_argument("-L", "--localport",
    help="comma-separated list of local ports to trace.")
parser.add_argument("-D", "--remoteport",
    help="comma-separated list of remote ports to trace.")
# hidden option: print the generated BPF program and exit
parser.add_argument("--ebpf", action="store_true",
    help=argparse.SUPPRESS)
args = parser.parse_args()
# set to a non-zero value to print the generated BPF program before loading
debug = 0
56
# define BPF program
#
# The C structs ipv4_data_t / ipv6_data_t below are read back in Python via
# the ctypes classes Data_ipv4 / Data_ipv6 at the end of this section: their
# field order and widths must be kept in sync.
bpf_text = """
#include <uapi/linux/ptrace.h>
#define KBUILD_MODNAME "foo"
#include <linux/tcp.h>
#include <net/sock.h>
#include <bcc/proto.h>

// time (ns) of the last observed state change, keyed by socket address
BPF_HASH(last, struct sock *, u64);

// separate data structs for ipv4 and ipv6
struct ipv4_data_t {
    u64 ts_us;
    u64 skaddr;
    u32 saddr;
    u32 daddr;
    u64 span_us;
    u32 pid;
    u64 ports;      // (lport << 32) + dport: needs 64 bits to hold both
    u32 oldstate;
    u32 newstate;
    char task[TASK_COMM_LEN];
};
BPF_PERF_OUTPUT(ipv4_events);

struct ipv6_data_t {
    u64 ts_us;
    u64 skaddr;
    unsigned __int128 saddr;
    unsigned __int128 daddr;
    u64 span_us;
    u32 pid;
    u64 ports;      // (lport << 32) + dport: needs 64 bits to hold both
    u32 oldstate;
    u32 newstate;
    char task[TASK_COMM_LEN];
};
BPF_PERF_OUTPUT(ipv6_events);

TRACEPOINT_PROBE(sock, inet_sock_set_state)
{
    if (args->protocol != IPPROTO_TCP)
        return 0;

    u32 pid = bpf_get_current_pid_tgid() >> 32;
    // sk is used as a UUID
    struct sock *sk = (struct sock *)args->skaddr;

    // lport is either used in a filter here, or later
    u16 lport = args->sport;
    FILTER_LPORT

    // dport is either used in a filter here, or later
    u16 dport = args->dport;
    FILTER_DPORT

    // read the clock once so span_us, ts_us and the stored timestamp agree
    u64 now = bpf_ktime_get_ns();

    // calculate delta since this socket's previous state change
    u64 *tsp, delta_us;
    tsp = last.lookup(&sk);
    if (tsp == 0)
        delta_us = 0;
    else
        delta_us = (now - *tsp) / 1000;

    if (args->family == AF_INET) {
        struct ipv4_data_t data4 = {
            .span_us = delta_us,
            .oldstate = args->oldstate,
            .newstate = args->newstate };
        data4.skaddr = (u64)args->skaddr;
        data4.ts_us = now / 1000;
        __builtin_memcpy(&data4.saddr, args->saddr, sizeof(data4.saddr));
        __builtin_memcpy(&data4.daddr, args->daddr, sizeof(data4.daddr));
        // a workaround until data4 compiles with separate lport/dport
        data4.ports = dport + ((0ULL + lport) << 32);
        data4.pid = pid;

        bpf_get_current_comm(&data4.task, sizeof(data4.task));
        ipv4_events.perf_submit(args, &data4, sizeof(data4));

    } else /* 6 */ {
        struct ipv6_data_t data6 = {
            .span_us = delta_us,
            .oldstate = args->oldstate,
            .newstate = args->newstate };
        data6.skaddr = (u64)args->skaddr;
        data6.ts_us = now / 1000;
        __builtin_memcpy(&data6.saddr, args->saddr_v6, sizeof(data6.saddr));
        __builtin_memcpy(&data6.daddr, args->daddr_v6, sizeof(data6.daddr));
        // a workaround until data6 compiles with separate lport/dport
        data6.ports = dport + ((0ULL + lport) << 32);
        data6.pid = pid;
        bpf_get_current_comm(&data6.task, sizeof(data6.task));
        ipv6_events.perf_submit(args, &data6, sizeof(data6));
    }

    if (args->newstate == TCP_CLOSE) {
        // the socket is finished: forget it, so the map does not fill up
        // and a later socket reusing this address starts with no history
        last.delete(&sk);
    } else {
        last.update(&sk, &now);
    }

    return 0;
}
"""

if not BPF.tracepoint_exists("sock", "inet_sock_set_state"):
    print("ERROR: tracepoint sock:inet_sock_set_state missing "
        "(added in Linux 4.16). Exiting")
    # non-zero status so callers/scripts can detect the failure
    exit(1)

# code substitutions: turn the -D / -L port lists into early-return filters.
# A filtered-out socket also has its "last" entry dropped.
if args.remoteport:
    dports = [int(dport) for dport in args.remoteport.split(',')]
    dports_if = ' && '.join(['dport != %d' % dport for dport in dports])
    bpf_text = bpf_text.replace('FILTER_DPORT',
        'if (%s) { last.delete(&sk); return 0; }' % dports_if)
if args.localport:
    lports = [int(lport) for lport in args.localport.split(',')]
    lports_if = ' && '.join(['lport != %d' % lport for lport in lports])
    bpf_text = bpf_text.replace('FILTER_LPORT',
        'if (%s) { last.delete(&sk); return 0; }' % lports_if)
# no filter requested: remove any remaining placeholders
bpf_text = bpf_text.replace('FILTER_DPORT', '')
bpf_text = bpf_text.replace('FILTER_LPORT', '')

if debug or args.ebpf:
    print(bpf_text)
    if args.ebpf:
        exit()

# event data
TASK_COMM_LEN = 16      # linux/sched.h

# ctypes mirror of struct ipv4_data_t in bpf_text (same field order/widths)
class Data_ipv4(ct.Structure):
    _fields_ = [
        ("ts_us", ct.c_ulonglong),
        ("skaddr", ct.c_ulonglong),
        ("saddr", ct.c_uint),
        ("daddr", ct.c_uint),
        ("span_us", ct.c_ulonglong),
        ("pid", ct.c_uint),
        ("ports", ct.c_ulonglong),      # lport in high 32 bits, dport in low
        ("oldstate", ct.c_uint),
        ("newstate", ct.c_uint),
        ("task", ct.c_char * TASK_COMM_LEN)
    ]

# ctypes mirror of struct ipv6_data_t in bpf_text (same field order/widths)
class Data_ipv6(ct.Structure):
    _fields_ = [
        ("ts_us", ct.c_ulonglong),
        ("skaddr", ct.c_ulonglong),
        ("saddr", (ct.c_ulonglong * 2)),
        ("daddr", (ct.c_ulonglong * 2)),
        ("span_us", ct.c_ulonglong),
        ("pid", ct.c_uint),
        ("ports", ct.c_ulonglong),      # lport in high 32 bits, dport in low
        ("oldstate", ct.c_uint),
        ("newstate", ct.c_uint),
        ("task", ct.c_char * TASK_COMM_LEN)
    ]
219
#
# Setup output formats
#
# Columns (every mode): SKADDR, C-PID, C-COMM, IP, LADDR, LPORT, RADDR,
# RPORT, OLDSTATE, NEWSTATE, MS. In the default mode the IP column is passed
# an empty string, so it takes no space.
#
# Don't change the default output (next 2 lines): this fits in 80 chars. I
# know it doesn't have NS or UIDs etc. I know. If you really, really, really
# need to add columns, columns that solve real actual problems, I'd start by
# adding an extended mode (-x) to include those columns.
#
header_string = "%-16s %-5s %-10.10s %s%-15s %-5s %-15s %-5s %-11s -> %-11s %s"
format_string = ("%-16x %-5d %-10.10s %s%-15s %-5d %-15s %-5d %-11s " +
    "-> %-11s %.3f")
if args.wide:
    # wide: room for full IPv6 addresses and an explicit IP-version column
    header_string = ("%-16s %-5s %-16.16s %-2s %-26s %-5s %-26s %-5s %-11s " +
        "-> %-11s %s")
    # ports are ints: format them with %d like the other modes do
    format_string = ("%-16x %-5d %-16.16s %-2s %-26s %-5d %-26s %-5d %-11s " +
        "-> %-11s %.3f")
if args.csv:
    header_string = "%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s"
    format_string = "%x,%d,%s,%s,%s,%s,%s,%d,%s,%s,%.3f"
239
def tcpstate2str(state):
    """Map a kernel TCP state number to its symbolic name.

    Values without a known name are returned as their str() form.
    """
    # Ordered as in include/net/tcp_states.h, where numbering starts at 1.
    names = dict(enumerate((
        "ESTABLISHED",
        "SYN_SENT",
        "SYN_RECV",
        "FIN_WAIT1",
        "FIN_WAIT2",
        "TIME_WAIT",
        "CLOSE",
        "CLOSE_WAIT",
        "LAST_ACK",
        "LISTEN",
        "CLOSING",
        "NEW_SYN_RECV",
    ), 1))

    name = names.get(state)
    return name if name is not None else str(state)
261
262# process event
def _print_time_columns(event):
    """Print the optional -T (HH:MM:SS) and -t (seconds) leading columns.

    Nothing is printed for options that are off, and no newline is written,
    so the event row can follow on the same line. The -t value is relative
    to the ts_us of the first event printed (shared across IPv4 and IPv6).
    """
    global start_ts
    if args.time:
        if args.csv:
            print("%s," % strftime("%H:%M:%S"), end="")
        else:
            print("%-8s " % strftime("%H:%M:%S"), end="")
    if args.timestamp:
        if start_ts == 0:
            start_ts = event.ts_us
        delta_s = (float(event.ts_us) - start_ts) / 1000000
        if args.csv:
            print("%.6f," % delta_s, end="")
        else:
            print("%-9.6f " % delta_s, end="")


def _print_event_row(event, ipver, saddr, daddr):
    """Print the main columns of one event using format_string.

    ipver: "4" or "6"; shown only in -w / -s modes, otherwise blank.
    saddr, daddr: local and remote addresses already rendered as text.
    ports is decoded as lport in the high 32 bits and dport in the low 32
    bits, matching how the BPF program packs it.
    """
    print(format_string % (event.skaddr, event.pid,
        event.task.decode('utf-8', 'replace'),
        ipver if args.wide or args.csv else "",
        saddr, event.ports >> 32,
        daddr, event.ports & 0xffffffff,
        tcpstate2str(event.oldstate), tcpstate2str(event.newstate),
        float(event.span_us) / 1000))


def print_ipv4_event(cpu, data, size):
    """Perf-buffer callback for ipv4_events: decode and print one event.

    cpu and size are supplied by the perf buffer and are unused here.
    """
    event = ct.cast(data, ct.POINTER(Data_ipv4)).contents
    _print_time_columns(event)
    # saddr/daddr are u32 copies of the address bytes; packing with native
    # byte order ("I") restores the original byte sequence for inet_ntop
    _print_event_row(event, "4",
        inet_ntop(AF_INET, pack("I", event.saddr)),
        inet_ntop(AF_INET, pack("I", event.daddr)))


def print_ipv6_event(cpu, data, size):
    """Perf-buffer callback for ipv6_events: decode and print one event.

    cpu and size are supplied by the perf buffer and are unused here.
    """
    event = ct.cast(data, ct.POINTER(Data_ipv6)).contents
    _print_time_columns(event)
    # saddr/daddr are 16-byte ctypes arrays; inet_ntop reads their buffer
    _print_event_row(event, "6",
        inet_ntop(AF_INET6, event.saddr),
        inet_ntop(AF_INET6, event.daddr))
308
# initialize BPF
b = BPF(text=bpf_text)

# header: optional time columns first, matching the per-event prefix order
if args.time:
    if args.csv:
        print("%s," % ("TIME"), end="")
    else:
        print("%-8s " % ("TIME"), end="")
if args.timestamp:
    if args.csv:
        print("%s," % ("TIME(s)"), end="")
    else:
        print("%-9s " % ("TIME(s)"), end="")
print(header_string % ("SKADDR", "C-PID", "C-COMM",
    "IP" if args.wide or args.csv else "",
    "LADDR", "LPORT", "RADDR", "RPORT",
    "OLDSTATE", "NEWSTATE", "MS"))

# ts_us of the first printed event; set lazily by the event callbacks for -t
start_ts = 0

# read events
b["ipv4_events"].open_perf_buffer(print_ipv4_event, page_cnt=64)
b["ipv6_events"].open_perf_buffer(print_ipv6_event, page_cnt=64)
while 1:
    try:
        b.perf_buffer_poll()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop tracing: exit without a traceback
        exit()
335