1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## This program is published under a GPLv2 license
5
6"""
7PacketList: holds several packets and allows to do operations on them.
8"""
9
10
11from __future__ import absolute_import
12from __future__ import print_function
13import os,subprocess
14from collections import defaultdict
15
16from scapy.config import conf
17from scapy.base_classes import BasePacket,BasePacketList
18from scapy.utils import do_graph,hexdump,make_table,make_lined_table,make_tex_table,get_temp_file
19
20from scapy.consts import plt, MATPLOTLIB_INLINED, MATPLOTLIB_DEFAULT_PLOT_KARGS
21from functools import reduce
22import scapy.modules.six as six
23from scapy.modules.six.moves import filter, range, zip
24
25
26#############
27## Results ##
28#############
29
class PacketList(BasePacketList):
    """A list of packets (or packet-like elements) with helpers to
    summarize, display, plot, graph and dump them.

    Subclasses may store other element types and override _elt2pkt /
    _elt2sum to map an element to a packet / a one-line summary.
    """
    __slots__ = ["stats", "res", "listname"]

    def __init__(self, res=None, name="PacketList", stats=None):
        """create a packet list from a list of packets
           res: the list of packets. If it is a PacketList, its underlying
                list is shared (not copied)
           name: name displayed by repr()
           stats: a list of classes that will appear in the stats
                  (defaults to conf.stats_classic_protocols)"""
        if stats is None:
            stats = conf.stats_classic_protocols
        self.stats = stats
        if res is None:
            res = []
        elif isinstance(res, PacketList):
            res = res.res
        self.res = res
        self.listname = name

    def __len__(self):
        return len(self.res)

    def _elt2pkt(self, elt):
        """Return the packet represented by a list element (identity here)."""
        return elt

    def _elt2sum(self, elt):
        """Return the one-line summary of a list element."""
        return elt.summary()

    def _elt2show(self, elt):
        """Return the text shown for a list element (its summary here)."""
        return self._elt2sum(elt)

    def __repr__(self):
        # Each element is counted under the FIRST class of self.stats (in
        # list order) whose layer it contains, otherwise under "Other".
        stats = {x: 0 for x in self.stats}
        other = 0
        for r in self.res:
            pkt = self._elt2pkt(r)
            for p in self.stats:
                if pkt.haslayer(p):
                    stats[p] += 1
                    break
            else:
                other += 1
        ct = conf.color_theme
        s = "".join(" %s%s%s" % (ct.packetlist_proto(p._name),
                                  ct.punct(":"),
                                  ct.packetlist_value(stats[p]))
                    for p in self.stats)
        s += " %s%s%s" % (ct.packetlist_proto("Other"),
                          ct.punct(":"),
                          ct.packetlist_value(other))
        return "%s%s%s%s%s" % (ct.punct("<"),
                               ct.packetlist_name(self.listname),
                               ct.punct(":"),
                               s,
                               ct.punct(">"))

    def __getattr__(self, attr):
        """Delegate unknown attributes (append, sort, ...) to the list."""
        # If the "res" slot itself is unset (e.g. on an instance built
        # without __init__, as copy/unpickling do), delegating would call
        # __getattr__("res") again and recurse forever.
        if attr == "res":
            raise AttributeError(attr)
        return getattr(self.res, attr)

    def __getitem__(self, item):
        """Index, slice or filter the list.
        item: a packet class -> new list of the elements containing that layer
              a slice        -> new list of the same class
              anything else  -> the raw element of the underlying list"""
        if isinstance(item, type) and issubclass(item, BasePacket):
            return self.__class__([x for x in self.res if item in self._elt2pkt(x)],
                                  name="%s from %s" % (item.__name__, self.listname))
        if isinstance(item, slice):
            return self.__class__(self.res.__getitem__(item),
                                  name="mod %s" % self.listname)
        return self.res.__getitem__(item)

    def __getslice__(self, *args, **kargs):
        # Python 2 only: slicing with simple bounds goes through here.
        return self.__class__(self.res.__getslice__(*args, **kargs),
                              name="mod %s" % self.listname)

    def __add__(self, other):
        """Concatenate two packet lists into a new list of this class."""
        return self.__class__(self.res + other.res,
                              name="%s+%s" % (self.listname, other.listname))

    def summary(self, prn=None, lfilter=None):
        """prints a summary of each packet
        prn:     function to apply to each packet instead of lambda x:x.summary()
        lfilter: truth function to apply to each packet to decide whether it will be displayed"""
        for r in self.res:
            if lfilter is not None and not lfilter(r):
                continue
            if prn is None:
                print(self._elt2sum(r))
            else:
                print(prn(r))

    def nsummary(self, prn=None, lfilter=None):
        """prints a summary of each packet with the packet's number
        prn:     function to apply to each packet instead of lambda x:x.summary()
        lfilter: truth function to apply to each packet to decide whether it will be displayed"""
        for i, res in enumerate(self.res):
            if lfilter is not None and not lfilter(res):
                continue
            print(conf.color_theme.id(i, fmt="%04i"), end=' ')
            if prn is None:
                print(self._elt2sum(res))
            else:
                print(prn(res))

    def display(self):  # Deprecated. Use show()
        """deprecated. is show()"""
        self.show()

    def show(self, *args, **kargs):
        """Best way to display the packet list. Defaults to nsummary() method"""
        return self.nsummary(*args, **kargs)

    def filter(self, func):
        """Returns a packet list filtered by a truth function"""
        return self.__class__([x for x in self.res if func(x)],
                              name="filtered %s" % self.listname)

    def make_table(self, *args, **kargs):
        """Prints a table using a function that returns for each packet its head column value, head row value and displayed value
        ex: p.make_table(lambda x:(x[IP].dst, x[TCP].dport, x[TCP].sprintf("%flags%")) """
        return make_table(self.res, *args, **kargs)

    def make_lined_table(self, *args, **kargs):
        """Same as make_table, but print a table with lines"""
        return make_lined_table(self.res, *args, **kargs)

    def make_tex_table(self, *args, **kargs):
        """Same as make_table, but print a table with LaTeX syntax"""
        return make_tex_table(self.res, *args, **kargs)

    def plot(self, f, lfilter=None, plot_xy=False, **kargs):
        """Applies a function to each packet to get a value that will be plotted
        with matplotlib. A list of matplotlib.lines.Line2D is returned.

        lfilter: a truth function that decides whether a packet must be plotted
        plot_xy: if true, f must return (x, y) couples instead of y values
        kargs:   passed to matplotlib's plot(); defaults to
                 MATPLOTLIB_DEFAULT_PLOT_KARGS when empty
        """

        # Apply f to the (selected) packets
        if lfilter is None:
            l = [f(e) for e in self.res]
        else:
            l = [f(e) for e in self.res if lfilter(e)]

        # Mimic the default gnuplot output
        if kargs == {}:
            kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS
        if plot_xy:
            lines = plt.plot(*zip(*l), **kargs)
        else:
            lines = plt.plot(l, **kargs)

        # Call show() if matplotlib is not inlined
        if not MATPLOTLIB_INLINED:
            plt.show()

        return lines

    def diffplot(self, f, delay=1, lfilter=None, **kargs):
        """diffplot(f, delay=1, lfilter=None)
        Applies a function to couples (l[i],l[i+delay])

        lfilter: a truth function applied to l[i] that decides whether the
                 couple (l[i], l[i+delay]) is plotted

        A list of matplotlib.lines.Line2D is returned.
        """

        # Apply f to every couple (l[i], l[i+delay]); the range stops early
        # enough that i+delay is always a valid index.
        count = len(self.res) - delay
        if lfilter is None:
            l = [f(self.res[i], self.res[i + delay])
                 for i in range(count)]
        else:
            l = [f(self.res[i], self.res[i + delay])
                 for i in range(count)
                 if lfilter(self.res[i])]

        # Mimic the default gnuplot output
        if kargs == {}:
            kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS
        lines = plt.plot(l, **kargs)

        # Call show() if matplotlib is not inlined
        if not MATPLOTLIB_INLINED:
            plt.show()

        return lines

    def multiplot(self, f, lfilter=None, plot_xy=False, **kargs):
        """Uses a function that returns a label and a value for this label, then
        plots all the values label by label.

        plot_xy: if true, values must be (x, y) couples

        A list of matplotlib.lines.Line2D is returned.
        """

        # Apply the function f to the (selected) packets, lazily
        if lfilter is None:
            l = (f(e) for e in self.res)
        else:
            l = (f(e) for e in self.res if lfilter(e))

        # Group the values by label
        d = {}
        for k, v in l:
            d.setdefault(k, []).append(v)

        # Mimic the default gnuplot output
        if not kargs:
            kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS

        if plot_xy:
            lines = [plt.plot(*zip(*pl), **dict(kargs, label=k))
                     for k, pl in six.iteritems(d)]
        else:
            lines = [plt.plot(pl, **dict(kargs, label=k))
                     for k, pl in six.iteritems(d)]
        plt.legend(loc="center right", bbox_to_anchor=(1.5, 0.5))

        # Call show() if matplotlib is not inlined
        if not MATPLOTLIB_INLINED:
            plt.show()

        return lines

    def rawhexdump(self):
        """Prints an hexadecimal dump of each packet in the list"""
        for p in self:
            hexdump(self._elt2pkt(p))

    def hexraw(self, lfilter=None):
        """Same as nsummary(), except that if a packet has a Raw layer, it will be hexdumped
        lfilter: a truth function that decides whether a packet must be displayed"""
        for i, res in enumerate(self.res):
            p = self._elt2pkt(res)
            if lfilter is not None and not lfilter(p):
                continue
            print("%s %s %s" % (conf.color_theme.id(i, fmt="%04i"),
                                p.sprintf("%.time%"),
                                self._elt2sum(res)))
            if p.haslayer(conf.raw_layer):
                hexdump(p.getlayer(conf.raw_layer).load)

    def hexdump(self, lfilter=None):
        """Same as nsummary(), except that packets are also hexdumped
        lfilter: a truth function that decides whether a packet must be displayed"""
        for i, res in enumerate(self.res):
            p = self._elt2pkt(res)
            if lfilter is not None and not lfilter(p):
                continue
            print("%s %s %s" % (conf.color_theme.id(i, fmt="%04i"),
                                p.sprintf("%.time%"),
                                self._elt2sum(res)))
            hexdump(p)

    def padding(self, lfilter=None):
        """Same as hexraw(), for Padding layer"""
        for i, res in enumerate(self.res):
            p = self._elt2pkt(res)
            if p.haslayer(conf.padding_layer):
                if lfilter is None or lfilter(p):
                    print("%s %s %s" % (conf.color_theme.id(i, fmt="%04i"),
                                        p.sprintf("%.time%"),
                                        self._elt2sum(res)))
                    hexdump(p.getlayer(conf.padding_layer).load)

    def nzpadding(self, lfilter=None):
        """Same as padding() but only non null padding"""
        for i, res in enumerate(self.res):
            p = self._elt2pkt(res)
            if p.haslayer(conf.padding_layer):
                pad = p.getlayer(conf.padding_layer).load
                # Skip padding made of one repeated byte (and empty padding).
                # pad[:1] is used rather than pad[0] because, for bytes on
                # Python 3, pad[0] is an int and the comparison never matches;
                # it also avoids an IndexError on empty padding.
                if pad == pad[:1] * len(pad):
                    continue
                if lfilter is None or lfilter(p):
                    print("%s %s %s" % (conf.color_theme.id(i, fmt="%04i"),
                                        p.sprintf("%.time%"),
                                        self._elt2sum(res)))
                    hexdump(pad)

    def conversations(self, getsrcdst=None, **kargs):
        """Graphes a conversations between sources and destinations and display it
        (using graphviz and imagemagick)
        getsrcdst: a function that takes an element of the list and
                   returns the source, the destination and optionally
                   a label. By default, returns the IP source and
                   destination from IP and ARP layers
        type: output type (svg, ps, gif, jpg, etc.), passed to dot's "-T" option
        target: filename or redirect. Defaults pipe to Imagemagick's display program
        prog: which graphviz program to use"""
        if getsrcdst is None:
            def getsrcdst(pkt):
                if 'IP' in pkt:
                    return (pkt['IP'].src, pkt['IP'].dst)
                if 'ARP' in pkt:
                    return (pkt['ARP'].psrc, pkt['ARP'].pdst)
                raise TypeError()
        # (src, dst) -> packet count, or -> set of labels when getsrcdst
        # returns 3-tuples
        conv = {}
        for p in self.res:
            p = self._elt2pkt(p)
            try:
                c = getsrcdst(p)
            except Exception:
                # No warning here: it's OK that getsrcdst() raises an
                # exception, since it might be, for example, a
                # function that expects a specific layer in each
                # packet. The try/except approach is faster and
                # considered more Pythonic than adding tests.
                continue
            if len(c) == 3:
                conv.setdefault(c[:2], set()).add(c[2])
            else:
                conv[c] = conv.get(c, 0) + 1
        gr = 'digraph "conv" {\n'
        for (s, d), l in six.iteritems(conv):
            gr += '\t "%s" -> "%s" [label="%s"]\n' % (
                s, d, ', '.join(str(x) for x in l) if isinstance(l, set) else l
            )
        gr += "}\n"
        return do_graph(gr, **kargs)

    def afterglow(self, src=None, event=None, dst=None, **kargs):
        """Experimental clone attempt of http://sourceforge.net/projects/afterglow
        each datum is reduced as src -> event -> dst and the data are graphed.
        by default we have IP.src -> IP.dport -> IP.dst

        src, event, dst: functions applied to each list element; elements
                         for which any of them raises are skipped
        kargs: passed to do_graph()"""
        if src is None:
            src = lambda x: x['IP'].src
        if event is None:
            event = lambda x: x['IP'].dport
        if dst is None:
            dst = lambda x: x['IP'].dst
        sl = {}  # src -> (count, [events seen for this src])
        el = {}  # event -> (count, [dsts seen for this event])
        dl = {}  # dst -> count
        for i in self.res:
            try:
                s, e, d = src(i), event(i), dst(i)
                if s in sl:
                    n, l = sl[s]
                    if e not in l:
                        l.append(e)
                    sl[s] = (n + 1, l)
                else:
                    sl[s] = (1, [e])
                if e in el:
                    n, l = el[e]
                    if d not in l:
                        l.append(d)
                    el[e] = (n + 1, l)
                else:
                    el[e] = (1, [d])
                dl[d] = dl.get(d, 0) + 1
            except Exception:
                continue

        def minmax(values):
            """Return (min, max) of values, adjusted so that max - min != 0."""
            values = list(values)
            if not values:
                # Nothing matched: no node will be drawn, any range works.
                return 0, 1
            m, M = min(values), max(values)
            if m == M:
                m = 0
            if M == 0:
                M = 1
            return m, M

        mins, maxs = minmax(x for x, _ in six.itervalues(sl))
        mine, maxe = minmax(x for x, _ in six.itervalues(el))
        mind, maxd = minmax(six.itervalues(dl))

        gr = 'digraph "afterglow" {\n\tedge [len=2.5];\n'

        # Node sizes are scaled to [1, 2] according to the counts.
        gr += "# src nodes\n"
        for s in sl:
            n, l = sl[s]
            n = 1 + float(n - mins) / (maxs - mins)
            gr += '"src.%s" [label = "%s", shape=box, fillcolor="#FF0000", style=filled, fixedsize=1, height=%.2f,width=%.2f];\n' % (repr(s), repr(s), n, n)
        gr += "# event nodes\n"
        for e in el:
            n, l = el[e]
            n = 1 + float(n - mine) / (maxe - mine)
            gr += '"evt.%s" [label = "%s", shape=circle, fillcolor="#00FFFF", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(e), repr(e), n, n)
        for d in dl:
            n = dl[d]
            n = 1 + float(n - mind) / (maxd - mind)
            gr += '"dst.%s" [label = "%s", shape=triangle, fillcolor="#0000ff", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(d), repr(d), n, n)

        gr += "###\n"
        for s in sl:
            n, l = sl[s]
            for e in l:
                gr += ' "src.%s" -> "evt.%s";\n' % (repr(s), repr(e))
        for e in el:
            n, l = el[e]
            for d in l:
                gr += ' "evt.%s" -> "dst.%s";\n' % (repr(e), repr(d))

        gr += "}"
        return do_graph(gr, **kargs)

    def _dump_document(self, **kargs):
        """Build a pyx document with one A4 page per packet (canvas_dump)."""
        import pyx
        d = pyx.document.document()
        l = len(self.res)
        for i, res in enumerate(self.res):
            c = self._elt2pkt(res).canvas_dump(**kargs)
            cbb = c.bbox()
            c.text(cbb.left(), cbb.top() + 1, r"\font\cmssfont=cmss12\cmssfont{Frame %i/%i}" % (i, l), [pyx.text.size.LARGE])
            if conf.verb >= 2:
                # progress indicator
                os.write(1, b".")
            d.append(pyx.document.page(c, paperformat=pyx.document.paperformat.A4,
                                       margin=1 * pyx.unit.t_cm,
                                       fittosize=1))
        return d

    def psdump(self, filename=None, **kargs):
        """Creates a multi-page PostScript file with a psdump of every packet
        filename: name of the file to write to. If empty, a temporary file is used and
                  conf.prog.psreader is called"""
        d = self._dump_document(**kargs)
        if filename is None:
            # Imported here so that a missing helper only affects this path.
            from scapy.utils import ContextManagerSubprocess
            filename = get_temp_file(autoext=".ps")
            d.writePSfile(filename)
            # Open exactly the file that was just written: the temporary
            # name already carries the ".ps" extension.
            with ContextManagerSubprocess("psdump()"):
                subprocess.Popen([conf.prog.psreader, filename])
        else:
            d.writePSfile(filename)
        # terminate the line of progress dots
        print()

    def pdfdump(self, filename=None, **kargs):
        """Creates a PDF file with a psdump of every packet
        filename: name of the file to write to. If empty, a temporary file is used and
                  conf.prog.pdfreader is called"""
        d = self._dump_document(**kargs)
        if filename is None:
            # Imported here so that a missing helper only affects this path.
            from scapy.utils import ContextManagerSubprocess
            filename = get_temp_file(autoext=".pdf")
            d.writePDFfile(filename)
            # Open exactly the file that was just written: the temporary
            # name already carries the ".pdf" extension.
            with ContextManagerSubprocess("pdfdump()"):
                subprocess.Popen([conf.prog.pdfreader, filename])
        else:
            d.writePDFfile(filename)
        # terminate the line of progress dots
        print()

    def sr(self, multi=0):
        """sr([multi=1]) -> (SndRcvList, PacketList)
        Matches packets in the list and return ( (matched couples), (unmatched packets) )
        multi: if true, a packet may be matched by several later packets;
               matched packets are tagged with an _answered attribute"""
        remain = self.res[:]
        sr = []
        i = 0
        while i < len(remain):
            s = remain[i]
            j = i
            while j < len(remain) - 1:
                j += 1
                r = remain[j]
                if r.answers(s):
                    sr.append((s, r))
                    if multi:
                        remain[i]._answered = 1
                        remain[j]._answered = 2
                        continue
                    del(remain[j])
                    del(remain[i])
                    # stay on index i, which now holds the next packet
                    i -= 1
                    break
            i += 1
        if multi:
            remain = [x for x in remain if not hasattr(x, "_answered")]
        return SndRcvList(sr), PacketList(remain)

    def sessions(self, session_extractor=None):
        """Group the list elements by session.
        session_extractor: function mapping a packet to a session key. The
                           default builds a string from the TCP/UDP/ICMP/IP/ARP
                           layers of Ethernet packets; other packets go to "Other"
        Returns a dict mapping each key to a list of the same class."""
        if session_extractor is None:
            def session_extractor(p):
                sess = "Other"
                if 'Ether' in p:
                    if 'IP' in p:
                        if 'TCP' in p:
                            sess = p.sprintf("TCP %IP.src%:%r,TCP.sport% > %IP.dst%:%r,TCP.dport%")
                        elif 'UDP' in p:
                            sess = p.sprintf("UDP %IP.src%:%r,UDP.sport% > %IP.dst%:%r,UDP.dport%")
                        elif 'ICMP' in p:
                            sess = p.sprintf("ICMP %IP.src% > %IP.dst% type=%r,ICMP.type% code=%r,ICMP.code% id=%ICMP.id%")
                        else:
                            sess = p.sprintf("IP %IP.src% > %IP.dst% proto=%IP.proto%")
                    elif 'ARP' in p:
                        sess = p.sprintf("ARP %ARP.psrc% > %ARP.pdst%")
                    else:
                        sess = p.sprintf("Ethernet type=%04xr,Ether.type%")
                return sess
        sessions = defaultdict(self.__class__)
        for p in self.res:
            sess = session_extractor(self._elt2pkt(p))
            sessions[sess].append(p)
        return dict(sessions)

    def replace(self, *args, **kargs):
        """
        lst.replace(<field>,[<oldvalue>,]<newvalue>)
        lst.replace( (fld,[ov],nv),(fld,[ov,]nv),...)
          if ov is None, all values are replaced
        ex:
          lst.replace( IP.src, "192.168.1.1", "10.0.0.1" )
          lst.replace( IP.ttl, 64 )
          lst.replace( (IP.ttl, 64), (TCP.sport, 666, 777), )
        delete_checksums (keyword): if true, checksums of modified packets
          are deleted so they get recomputed
        Modified packets are copies; the original packets are left untouched.
        """
        delete_checksums = kargs.get("delete_checksums", False)
        x = PacketList(name="Replaced %s" % self.listname)
        if not isinstance(args[0], tuple):
            args = (args,)
        for p in self.res:
            p = self._elt2pkt(p)
            copied = False
            for scheme in args:
                fld = scheme[0]
                old = scheme[1]  # not used if len(scheme) == 2
                new = scheme[-1]
                for o in fld.owners:
                    if o in p:
                        if len(scheme) == 2 or p[o].getfieldval(fld.name) == old:
                            if not copied:
                                # copy lazily, only once per packet
                                p = p.copy()
                                if delete_checksums:
                                    p.delete_checksums()
                                copied = True
                            setattr(p[o], fld.name, new)
            x.append(p)
        return x
539
540
class SndRcvList(PacketList):
    """Packet list whose elements are (sent, received) packet couples.

    Per-element operations (stats, filtering by layer, hexdumps, ...) use
    the received packet; summaries show both packets of the couple.
    """
    __slots__ = []

    def __init__(self, res=None, name="Results", stats=None):
        PacketList.__init__(self, res=res, name=name, stats=stats)

    def _elt2pkt(self, elt):
        # The received packet stands for the whole couple.
        received = elt[1]
        return received

    def _elt2sum(self, elt):
        sent, received = elt[0], elt[1]
        return "%s ==> %s" % (sent.summary(), received.summary())
549