1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## This program is published under a GPLv2 license
5
"""
Generators and packet meta classes.

Defines the ``Gen`` value generators (``SetGen``, ``Net``, ``OID``), the
``Packet_metaclass`` / ``Field_metaclass`` metaclasses and the
``BasePacket`` / ``BasePacketList`` base classes.
"""
9
10###############
11## Generators ##
12################
13
14from __future__ import absolute_import
15import re,random,socket
16import types
17from scapy.modules.six.moves import range
18
class Gen(object):
    """Root of the value-generator hierarchy.

    The base implementation yields nothing; subclasses override
    ``__iter__`` to enumerate their values.
    """
    __slots__ = []

    def __iter__(self):
        # Nothing to produce: hand back an iterator over an empty tuple.
        return iter(())
23
class SetGen(Gen):
    """Generator over an explicit set of values.

    ``values`` may be:

    * a ``list`` or ``BasePacketList`` -- its elements are copied into a list;
    * a 2- or 3-tuple of ``int``-convertible items -- read as
      ``(start, stop[, step])`` with ``stop`` made inclusive (``stop + 1``);
    * anything else -- kept as a single value.

    When iterated, nested ``Gen`` objects, ranges and generators are
    flattened; nested packets are flattened only if ``_iterpacket`` is true.
    """
    def __init__(self, values, _iterpacket=1):
        self._iterpacket = _iterpacket
        if isinstance(values, (list, BasePacketList)):
            self.values = list(values)
            return
        looks_like_range = (
            isinstance(values, tuple)
            and 2 <= len(values) <= 3
            and all(hasattr(item, "__int__") for item in values)
        )
        if looks_like_range:
            # Adding 1 to the stop bound keeps the historical inclusive
            # meaning of tuples used as field `values`.
            bounds = [int(item) for item in values]
            bounds[1] += 1
            self.values = [range(*bounds)]
        else:
            self.values = [values]

    def transf(self, element):
        return element

    def __iter__(self):
        for item in self.values:
            expandable = isinstance(item, (range, types.GeneratorType))
            if not expandable and isinstance(item, Gen):
                # Packets are themselves Gens; only expand them on request.
                expandable = self._iterpacket or not isinstance(item, BasePacket)
            if expandable:
                for sub in item:
                    yield sub
            else:
                yield item

    def __repr__(self):
        return "<SetGen %r>" % (self.values,)
50
class Net(Gen):
    """Generate a list of IPs from a network address or a name.

    ``net`` is a dotted quad whose octets may each be a number, ``*`` or an
    ``x-y`` range, optionally followed by ``/prefixlen``.  Any string that
    does not match ``ip_regex`` is resolved with ``socket.gethostbyname``.
    """
    name = "ip"
    ip_regex = re.compile(r"^(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)(/[0-3]?[0-9])?$")

    @staticmethod
    def _parse_digit(a, netmask):
        """Return the half-open ``(start, stop)`` range covered by one octet.

        :param a: octet text: a number, ``"*"`` or ``"x-y"``.
        :param netmask: number of host bits falling in this octet; clamped
            to 0..8.
        :returns: ``(start, stop)`` suitable for ``range()``.
        """
        netmask = min(8, max(netmask, 0))
        if a == "*":
            return (0, 256)
        net_bits = 0xff << netmask          # keeps the network part of the octet
        host_bits = 0xff >> (8 - netmask)   # all ones over the host part
        if "-" in a:
            x, y = [int(d) for d in a.split('-')]
            if x > y:
                y = x
            return (x & net_bits, max(y, x | host_bits) + 1)
        value = int(a)
        return (value & net_bits, (value | host_bits) + 1)

    @classmethod
    def _parse_net(cls, net):
        """Parse ``"addr[/prefixlen]"`` into per-octet ranges.

        :returns: ``(ranges, netmask)`` where ``ranges`` holds one
            ``(start, stop)`` tuple per octet and ``netmask`` is the prefix
            length (32 when omitted).
        """
        tmp = net.split('/') + ["32"]
        if not cls.ip_regex.match(net):
            tmp[0] = socket.gethostbyname(tmp[0])
        netmask = int(tmp[1])
        # y is the bit position where each octet ends, so y - netmask is the
        # count of host bits inside that octet.
        ret_list = [cls._parse_digit(x, y - netmask)
                    for (x, y) in zip(tmp[0].split('.'), [8, 16, 24, 32])]
        return ret_list, netmask

    def __init__(self, net):
        self.repr = net
        self.parsed, self.netmask = self._parse_net(net)

    def __str__(self):
        try:
            return next(self.__iter__())
        except StopIteration:
            return None

    def __iter__(self):
        # The first octet varies fastest.
        for d in range(*self.parsed[3]):
            for c in range(*self.parsed[2]):
                for b in range(*self.parsed[1]):
                    for a in range(*self.parsed[0]):
                        yield "%i.%i.%i.%i" % (a, b, c, d)

    def choice(self):
        """Return one random address from the set, as a dotted string."""
        return ".".join(str(random.randint(lo, hi - 1)) for lo, hi in self.parsed)

    def __repr__(self):
        return "Net(%r)" % self.repr

    def _parsed_of(self, other):
        """Return the per-octet ranges of ``other`` (a ``Net`` or a string)."""
        if hasattr(other, "parsed"):
            return other.parsed
        return self._parse_net(other)[0]

    def __eq__(self, other):
        try:
            p2 = self._parsed_of(other)
        except (AttributeError, TypeError):
            # ``other`` is neither a Net nor a text address (e.g. None or an
            # int): defer to Python's default comparison instead of leaking
            # the parsing error.
            return NotImplemented
        return self.parsed == p2

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __contains__(self, other):
        p2 = self._parsed_of(other)
        # ``other`` is inside ``self`` when every octet range is nested.
        return all(a1 <= a2 and b2 <= b1
                   for (a1, b1), (a2, b2) in zip(self.parsed, p2))

    def __rcontains__(self, other):
        return self in self.__class__(other)
120
121
class OID(Gen):
    """Generate OID strings from a dotted pattern containing ranges.

    Each dotted component written as ``"a-b"`` is expanded over the integers
    ``a`` to ``b`` inclusive; other components are copied verbatim.  The
    first range component varies fastest.
    """
    name = "OID"

    def __init__(self, oid):
        self.oid = oid
        self.cmpt = []  # (low, high) bounds of each range component
        fmt = []
        for i in oid.split("."):
            if "-" in i:
                fmt.append("%i")
                self.cmpt.append(tuple(map(int, i.split("-"))))
            else:
                fmt.append(i)
        self.fmt = ".".join(fmt)

    def __repr__(self):
        return "OID(%r)" % self.oid

    def __iter__(self):
        ii = [k[0] for k in self.cmpt]
        while True:
            yield self.fmt % tuple(ii)
            # Advance like an odometer: bump the first counter still below
            # its upper bound, resetting the counters before it.
            i = 0
            while True:
                if i >= len(ii):
                    # Every counter wrapped around: iteration is over.  Use
                    # ``return``; ``raise StopIteration`` inside a generator
                    # is turned into RuntimeError by PEP 479 (Python 3.7+).
                    return
                if ii[i] < self.cmpt[i][1]:
                    ii[i] += 1
                    break
                ii[i] = self.cmpt[i][0]
                i += 1
151
152
153
154######################################
155## Packet abstract and base classes ##
156######################################
157
class Packet_metaclass(type):
    """Metaclass for packet classes.

    When a class is created it:

    * resolves ``fields_desc``, inlining the fields of any packet class
      listed in it, or inherits the first base's ``fields_desc``;
    * turns class attributes named after a field into that field's new
      default (on a copy of the field);
    * moves ``name`` / ``overload_fields`` to ``_name`` / ``_overload_fields``;
    * computes ``__all_slots__`` and ``aliastypes``, calls
      ``register_variant`` / ``register_owner`` hooks when present and
      registers the class in ``conf.layers``.
    """
    def __new__(cls, name, bases, dct):
        if "fields_desc" in dct:  # perform resolution of references to other packets
            resolved_fld = []
            for f in dct["fields_desc"]:
                if isinstance(f, Packet_metaclass):  # reference to another fields_desc
                    resolved_fld.extend(f.fields_desc)
                else:
                    resolved_fld.append(f)
            # Always store the resolved list, so references that expand to
            # nothing are not left behind as bogus "fields".
            dct["fields_desc"] = resolved_fld
        else:  # look for a fields_desc in parent classes
            resolved_fld = None
            for b in bases:
                if hasattr(b, "fields_desc"):
                    resolved_fld = b.fields_desc
                    break

        if resolved_fld:  # perform default value replacements
            final_fld = []
            for f in resolved_fld:
                if f.name in dct:
                    # Override the default on a copy so the field object
                    # shared with other classes is left untouched.
                    f = f.copy()
                    f.default = dct.pop(f.name)
                final_fld.append(f)
            dct["fields_desc"] = final_fld

        if "__slots__" not in dct:
            dct["__slots__"] = []
        for attr in ["name", "overload_fields"]:
            try:
                dct["_%s" % attr] = dct.pop(attr)
            except KeyError:
                pass
        newcls = super(Packet_metaclass, cls).__new__(cls, name, bases, dct)
        # Union of every slot name declared along the MRO.
        newcls.__all_slots__ = set(
            attr
            for klass in newcls.__mro__ if hasattr(klass, "__slots__")
            for attr in klass.__slots__
        )

        if hasattr(newcls, "aliastypes"):
            newcls.aliastypes = [newcls] + newcls.aliastypes
        else:
            newcls.aliastypes = [newcls]

        if hasattr(newcls, "register_variant"):
            newcls.register_variant()
        for f in newcls.fields_desc:
            if hasattr(f, "register_owner"):
                f.register_owner(newcls)
        from scapy import config
        config.conf.layers.register(newcls)
        return newcls

    def __getattr__(self, attr):
        """Expose the fields of ``fields_desc`` as class attributes by name.

        :raises AttributeError: if no field is named ``attr``.
        """
        # Without this guard, a class lacking ``fields_desc`` re-enters
        # __getattr__ forever (RecursionError) instead of raising the
        # AttributeError that hasattr() relies on.
        if attr == "fields_desc":
            raise AttributeError(attr)
        for k in self.fields_desc:
            if k.name == attr:
                return k
        raise AttributeError(attr)

    def __call__(cls, *args, **kargs):
        if "dispatch_hook" in cls.__dict__:
            try:
                cls = cls.dispatch_hook(*args, **kargs)
            except Exception:
                # A failing dispatch hook falls back to the raw layer unless
                # dissector debugging is on.  ``Exception`` rather than a bare
                # except so KeyboardInterrupt/SystemExit still propagate.
                from scapy import config
                if config.conf.debug_dissector:
                    raise
                cls = config.conf.raw_layer
        i = cls.__new__(cls, cls.__name__, cls.__bases__, cls.__dict__)
        i.__init__(*args, **kargs)
        return i
233
class Field_metaclass(type):
    """Metaclass that gives classes an empty ``__slots__`` unless they declare one."""
    def __new__(cls, name, bases, dct):
        # Without an explicit __slots__, instances would carry a __dict__.
        dct.setdefault("__slots__", [])
        return super(Field_metaclass, cls).__new__(cls, name, bases, dct)
240
class NewDefaultValues(Packet_metaclass):
    """NewDefaultValues is deprecated (not needed anymore).

    Remove this:
        __metaclass__ = NewDefaultValues
    and it should still work.  Using it only logs a deprecation warning
    and then behaves like ``Packet_metaclass``.
    """
    def __new__(cls, name, bases, dct):
        from scapy.error import log_loading
        import traceback
        # Find a stack frame whose source line is a ``class`` statement so
        # the warning can point at it; the trailing sentinel supplies
        # "??"/-1 when none is found.
        for frame in traceback.extract_stack() + [("??", -1, None, "")]:
            fname, lineno, _, line = frame
            # ``line`` may be None when the source text is unavailable;
            # calling .startswith on it used to abort class creation.
            if line and line.startswith("class"):
                break
        log_loading.warning("Deprecated (no more needed) use of NewDefaultValues  (%s l. %i).", fname, lineno)

        return super(NewDefaultValues, cls).__new__(cls, name, bases, dct)
262
class BasePacket(Gen):
    """Base class for packet objects.

    ``SetGen`` checks ``isinstance(x, BasePacket)`` so that packets found in
    its values are expanded only when its ``_iterpacket`` flag is true.
    """
    __slots__ = []
265
266
267#############################
268## Packet list base class  ##
269#############################
270
class BasePacketList(object):
    """Base class for packet lists.

    ``SetGen`` accepts instances of it like plain lists and copies their
    elements with ``list()``.
    """
    __slots__ = []
273