1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## Modified by Maxence Tury <maxence.tury@ssi.gouv.fr>
5## This program is published under a GPLv2 license
6
7"""
8ASN.1 (Abstract Syntax Notation One)
9"""
10
11from __future__ import absolute_import
12from __future__ import print_function
13import random
14from datetime import datetime
15from scapy.config import conf
16from scapy.error import Scapy_Exception, warning
17from scapy.volatile import RandField, RandIP, GeneralizedTime
18from scapy.utils import Enum_metaclass, EnumElement, binrepr
19from scapy.compat import plain_str, chb, raw, orb
20import scapy.modules.six as six
21from scapy.modules.six.moves import range
22
class RandASN1Object(RandField):
    """Volatile field producing a random ASN.1 object each time it is fixed.

    :param objlist: ASN1_Object subclasses to draw from.  When omitted,
        every object class registered on an ASN1_Class_UNIVERSAL tag is
        eligible.
    """
    def __init__(self, objlist=None):
        if objlist is None:
            objlist = [tag._asn1_obj
                       for tag in six.itervalues(ASN1_Class_UNIVERSAL.__rdict__)
                       if hasattr(tag, "_asn1_obj")]
        self.objlist = objlist
        # alphabet used for the payload of random string objects
        self.chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

    def _fix(self, n=0):
        """Return one random ASN.1 object.

        :param n: current nesting depth.  SEQUENCE-like classes only get
            random children while ``n < 10``; deeper picks (and classes with
            no dedicated branch) fall back to a random ASN1_INTEGER.
        """
        klass = random.choice(self.objlist)
        # Branch order matters: IPADDRESS and the two time classes derive
        # from ASN1_STRING, so they must be tested before it.
        if issubclass(klass, ASN1_INTEGER):
            value = int(random.gauss(0, 1000))
        elif issubclass(klass, ASN1_IPADDRESS):
            value = RandIP()._fix()
        elif issubclass(klass, (ASN1_GENERALIZED_TIME, ASN1_UTC_TIME)):
            value = GeneralizedTime()._fix()
        elif issubclass(klass, ASN1_STRING):
            length = int(random.expovariate(0.05) + 1)
            value = "".join(random.choice(self.chars) for _ in range(length))
        elif issubclass(klass, ASN1_SEQUENCE) and n < 10:
            count = int(random.expovariate(0.08) + 1)
            # children come from a fresh generator sharing the same objlist
            value = [type(self)(objlist=self.objlist)._fix(n + 1)
                     for _ in range(count)]
        else:
            return ASN1_INTEGER(int(random.gauss(0, 1000)))
        return klass(value)
49
50
51##############
52#### ASN1 ####
53##############
54
class ASN1_Error(Scapy_Exception):
    """Root of the ASN.1 exception hierarchy (a Scapy_Exception)."""


class ASN1_Encoding_Error(ASN1_Error):
    """ASN.1 error category for encoding problems."""


class ASN1_Decoding_Error(ASN1_Error):
    """ASN.1 error category for decoding problems."""


class ASN1_BadTag_Decoding_Error(ASN1_Decoding_Error):
    """Decoding error subclass; by its name, presumably raised on an unexpected tag."""
66
67
68
class ASN1Codec(EnumElement):
    """Enumeration element standing for one set of ASN.1 encoding rules.

    ASN1_Codecs_metaclass declares ``element_class = ASN1Codec``, so the
    members of ASN1_Codecs are (presumably) instances of this class.  Despite
    the ``cls`` parameter name, the methods below are ordinary instance
    methods.  Decoding is delegated to a "stem" object attached with
    register_stem().
    """
    def register_stem(cls, stem):
        """Attach ``stem``, the object providing dec()/safedec() for this codec."""
        cls._stem = stem

    def dec(cls, s, context=None):
        """Decode ``s`` through the registered stem's ``dec``."""
        return cls._stem.dec(s, context=context)

    def safedec(cls, s, context=None):
        """Decode ``s`` through the registered stem's ``safedec``."""
        return cls._stem.safedec(s, context=context)

    def get_stem(cls):
        """Return the stem attached with register_stem()."""
        # register_stem() stores the stem as ``_stem``; reading ``stem``
        # always raised AttributeError.
        return cls._stem
78
79
class ASN1_Codecs_metaclass(Enum_metaclass):
    """Enum metaclass whose enumeration elements are ASN1Codec objects."""

    element_class = ASN1Codec
82
class ASN1_Codecs(six.with_metaclass(ASN1_Codecs_metaclass)):
    """Identifiers of the ASN.1 encoding rule sets known to Scapy.

    With ASN1_Codecs_metaclass (element_class = ASN1Codec), each integer
    below presumably becomes an ASN1Codec element.  The numbers look like
    arbitrary identifiers, not wire values.  Whether a given rule set has an
    implementation is not visible here; BER is the one installed as
    ``conf.ASN1_default_codec`` at the end of this module.
    """
    BER = 1     # Basic Encoding Rules (ITU-T X.690)
    DER = 2     # Distinguished Encoding Rules (ITU-T X.690)
    PER = 3     # Packed Encoding Rules (ITU-T X.691)
    CER = 4     # Canonical Encoding Rules (ITU-T X.690)
    LWER = 5
    BACnet = 6
    OER = 7     # Octet Encoding Rules (ITU-T X.696)
    SER = 8
    XER = 9     # XML Encoding Rules (ITU-T X.693)
93
class ASN1Tag(EnumElement):
    """Enumeration element describing one ASN.1 tag.

    Besides the key/value pair kept by EnumElement, a tag stores:
      - ``_context``: the context passed at construction (may be None);
      - ``_codec``: dict mapping codec identifiers to codec implementations,
        filled through register();
      - ``_asn1_obj``: the ASN1_Object subclass bound through
        register_asn1_object(), once one has been registered.
    """
    def __init__(self, key, value, context=None, codec=None):
        EnumElement.__init__(self, key, value)
        self._context = context
        if codec is None:
            codec = {}
        self._codec = codec

    def clone(self):  # /!\ not a real deep copy. self._codec is shared
        """Return a new tag with the same key, value, context and codec dict.

        The codec dict is shared, so a codec registered on either tag is
        visible from both.  ``_asn1_obj`` is not copied.
        """
        return self.__class__(self._key, self._value, self._context, self._codec)

    def register_asn1_object(self, asn1obj):
        """Bind ``asn1obj`` as the object class instantiated by asn1_object()."""
        self._asn1_obj = asn1obj

    def asn1_object(self, val):
        """Instantiate the registered object class with ``val``.

        :raises ASN1_Error: if no object class was registered on this tag.
        """
        if hasattr(self, "_asn1_obj"):
            return self._asn1_obj(val)
        raise ASN1_Error("%r does not have any assigned ASN1 object" % self)

    def register(self, codecnum, codec):
        """Register ``codec`` as this tag's implementation for ``codecnum``."""
        self._codec[codecnum] = codec

    def get_codec(self, codec):
        """Return the implementation registered for codec id ``codec``.

        :raises ASN1_Error: if nothing is registered under that id.
        """
        try:
            return self._codec[codec]
        except KeyError:
            pass
        # Raised outside the handler so Python 3 does not attach the
        # internal KeyError as context ("During handling of ...").
        raise ASN1_Error("Codec %r not found for tag %r" % (codec, self))
117
class ASN1_Class_metaclass(Enum_metaclass):
    """Metaclass of the ASN.1 tag classes (ASN1_Class and its subclasses).

    While building a class it:
      1. copies in a clone of every ASN1Tag found in each direct base's
         namespace, unless the new class defines that name itself;
      2. replaces every integer class attribute with an ASN1Tag;
      3. stores every tag in ``__rdict__``, mapping each tag to itself;
      4. sets ``context`` on each tag of the new class's own namespace to
         the new class.
    """
    element_class = ASN1Tag
    def __new__(cls, name, bases, dct): # XXX factorise a bit with Enum_metaclass.__new__()
        # Step 1: inherit tags.  Cloning gives the subclass its own tag
        # objects (the codec dict is still shared by clone()), so step 4
        # never rewrites the tags owned by the bases.
        for b in bases:
            for k,v in six.iteritems(b.__dict__):
                if k not in dct and isinstance(v,ASN1Tag):
                    dct[k] = v.clone()

        # Steps 2-3.  Only values are replaced during iteration (no keys
        # are added or removed), so mutating dct inside the loop is safe.
        rdict = {}
        for k,v in six.iteritems(dct):
            if isinstance(v, int):
                v = ASN1Tag(k,v)
                dct[k] = v
                rdict[v] = v
            elif isinstance(v, ASN1Tag):
                rdict[v] = v
        dct["__rdict__"] = rdict

        # type.__new__ is called directly, bypassing Enum_metaclass.__new__.
        cls = type.__new__(cls, name, bases, dct)
        # Step 4.  NOTE(review): ASN1Tag keeps its context in ``_context``
        # (set by __init__, copied by clone), but this assigns a separate
        # ``context`` attribute; ``_context`` keeps its construction value
        # (None for tags created from ints above).  Confirm which attribute
        # readers elsewhere rely on.
        for v in cls.__dict__.values():
            if isinstance(v, ASN1Tag):
                v.context = cls # overwrite ASN1Tag contexts, even cloned ones
        return cls
141
142
class ASN1_Class(six.with_metaclass(ASN1_Class_metaclass)):
    """Root of the ASN.1 tag classes; built through ASN1_Class_metaclass."""
145
class ASN1_Class_UNIVERSAL(ASN1_Class):
    """Tags of the ASN.1 UNIVERSAL class, plus a few Scapy-specific extras.

    ASN1_Class_metaclass turns each integer below into an ASN1Tag.  As the
    inline comments say, some values fold in flag bits: 0x20 marks a
    constructed encoding and 0x40 an application-specific one.  The negative
    values are not real tag numbers; they tag Scapy pseudo-objects
    (ASN1_DECODING_ERROR uses ERROR, ASN1_force uses RAW).
    """
    name = "UNIVERSAL"
    ERROR = -3                  # pseudo-tag for ASN1_DECODING_ERROR
    RAW = -2                    # pseudo-tag for ASN1_force
    NONE = -1
    ANY = 0                     # default tag of the ASN1_Object base class
    BOOLEAN = 1
    INTEGER = 2
    BIT_STRING = 3
    STRING = 4
    NULL = 5
    OID = 6
    OBJECT_DESCRIPTOR = 7
    EXTERNAL = 8
    REAL = 9
    ENUMERATED = 10
    # NOTE(review): X.680 names universal tag 11 "EMBEDDED PDV"; "PDF" looks
    # like a typo, kept because callers may reference this attribute name.
    EMBEDDED_PDF = 11
    UTF8_STRING = 12
    RELATIVE_OID = 13
    SEQUENCE = 16|0x20          # constructed encoding
    SET = 17|0x20               # constructed encoding
    NUMERIC_STRING = 18
    PRINTABLE_STRING = 19
    T61_STRING = 20             # aka TELETEX_STRING
    VIDEOTEX_STRING = 21
    IA5_STRING = 22
    UTC_TIME = 23
    GENERALIZED_TIME = 24
    GRAPHIC_STRING = 25
    ISO646_STRING = 26          # aka VISIBLE_STRING
    GENERAL_STRING = 27
    UNIVERSAL_STRING = 28
    CHAR_STRING = 29
    BMP_STRING = 30
    IPADDRESS = 0|0x40          # application-specific encoding
    COUNTER32 = 1|0x40          # application-specific encoding
    GAUGE32 = 2|0x40            # application-specific encoding
    TIME_TICKS = 3|0x40         # application-specific encoding
184
185
class ASN1_Object_metaclass(type):
    """Metaclass registering every new ASN1 object class on its tag.

    After the class is created, ``c.tag.register_asn1_object(c)`` binds it
    so the tag can later build instances (ASN1Tag.asn1_object).  A failed
    registration is reported through ``warning`` instead of being raised.
    """
    def __new__(cls, name, bases, dct):
        c = super(ASN1_Object_metaclass, cls).__new__(cls, name, bases, dct)
        try:
            c.tag.register_asn1_object(c)
        except Exception:
            # Only use values that are guaranteed to exist: the class may
            # lack a usable ``tag``, and ASN1 objects have no ``codec``
            # attribute, so formatting ``c.codec`` raised AttributeError
            # from inside this handler.
            warning("Error registering %r for %r" % (getattr(c, "tag", None), name))
        return c
194
class ASN1_Object(six.with_metaclass(ASN1_Object_metaclass)):
    """Base class of every ASN.1 value.

    ``val`` holds the wrapped Python value.  The class attribute ``tag`` (an
    ASN1Tag) selects the codec used by ``enc``.  Rich comparisons are
    delegated to ``val``, so e.g. an ASN1_INTEGER compares equal to the
    plain int it wraps.

    NOTE: ``__eq__`` is defined without ``__hash__``, so on Python 3
    instances are unhashable.
    """
    tag = ASN1_Class_UNIVERSAL.ANY

    def __init__(self, val):
        self.val = val

    def enc(self, codec):
        """Encode ``val`` with the implementation ``self.tag`` has for ``codec``.

        :raises ASN1_Error: if the tag has no implementation for ``codec``.
        """
        return self.tag.get_codec(codec).enc(self.val)

    def __repr__(self):
        return "<%s[%r]>" % (self.__dict__.get("name", self.__class__.__name__), self.val)

    def __str__(self):
        # enc() yields the encoded bytes (it also backs __bytes__), but on
        # Python 3 __str__ must return text, otherwise str()/print() raise
        # TypeError.  plain_str converts for display; use bytes()/raw() to
        # get the exact encoding.
        return plain_str(self.enc(conf.ASN1_default_codec))

    def __bytes__(self):
        return self.enc(conf.ASN1_default_codec)

    def strshow(self, lvl=0):
        """Return ``repr(self)`` indented by ``lvl`` levels, newline-terminated."""
        return ("  "*lvl)+repr(self)+"\n"

    def show(self, lvl=0):
        """Print the strshow() rendering."""
        print(self.strshow(lvl))

    # Rich comparisons: all delegate to the wrapped value.
    def __eq__(self, other):
        return self.val == other

    def __lt__(self, other):
        return self.val < other

    def __le__(self, other):
        return self.val <= other

    def __gt__(self, other):
        return self.val > other

    def __ge__(self, other):
        return self.val >= other

    def __ne__(self, other):
        return self.val != other
223
224
225#######################
226####  ASN1 objects ####
227#######################
228
229# on the whole, we order the classes by ASN1_Class_UNIVERSAL tag value
230
class ASN1_DECODING_ERROR(ASN1_Object):
    """Placeholder object carrying data that could not be decoded.

    :param val: the undecodable data, or an ASN1_Object wrapping it.
    :param exc: the exception raised while decoding (optional).
    """
    tag = ASN1_Class_UNIVERSAL.ERROR

    def __init__(self, val, exc=None):
        ASN1_Object.__init__(self, val)
        self.exc = exc

    def __repr__(self):
        # ``exc`` defaults to None and an exception may have no args; fall
        # back to the exception object itself instead of crashing in repr.
        args = getattr(self.exc, "args", None)
        detail = args[0] if args else self.exc
        return "<%s[%r]{{%r}}>" % (self.__dict__.get("name", self.__class__.__name__),
                                   self.val, detail)

    def enc(self, codec):
        """Re-encode an ASN1_Object ``val``; return any other ``val`` unchanged."""
        if isinstance(self.val, ASN1_Object):
            return self.val.enc(codec)
        return self.val
243
class ASN1_force(ASN1_Object):
    """Wrapper whose ``enc`` emits ``val`` as-is.

    If ``val`` is an ASN1_Object it is encoded with the requested codec;
    any other value is returned unchanged, with nothing added around it.
    """
    tag = ASN1_Class_UNIVERSAL.RAW

    def enc(self, codec):
        inner = self.val
        if not isinstance(inner, ASN1_Object):
            return inner
        return inner.enc(codec)


class ASN1_BADTAG(ASN1_force):
    """ASN1_force subclass with identical behaviour.

    It inherits the RAW tag, so ASN1_Object_metaclass registers it on RAW
    when the class is defined, replacing ASN1_force as RAW's object class.
    """
253
class ASN1_INTEGER(ASN1_Object):
    """ASN.1 INTEGER; repr shows the value both in hex and in decimal."""
    tag = ASN1_Class_UNIVERSAL.INTEGER

    def __repr__(self):
        hex_txt = hex(self.val)
        # Python 2 longs render as "0x...L": drop the suffix.
        if hex_txt.endswith("L"):
            hex_txt = hex_txt[:-1]
        # cut at 22 because with leading '0x', x509 serials should be < 23
        if len(hex_txt) > 22:
            hex_txt = "%s...%s" % (hex_txt[:12], hex_txt[-10:])
        dec_txt = repr(self.val)
        if len(dec_txt) > 20:
            dec_txt = "%s...%s" % (dec_txt[:10], dec_txt[-10:])
        label = self.__dict__.get("name", self.__class__.__name__)
        return "%s <%s[%s]>" % (hex_txt, label, dec_txt)
267
268
class ASN1_BOOLEAN(ASN1_INTEGER):
    """ASN.1 BOOLEAN stored as an integer.

    BER: 0 means False, anything else means True.  repr prefixes that truth
    value to the generic ASN1_Object representation.
    """
    tag = ASN1_Class_UNIVERSAL.BOOLEAN

    def __repr__(self):
        is_zero = self.val == 0
        return '%s %s' % (not is_zero, ASN1_Object.__repr__(self))
274
275
class ASN1_BIT_STRING(ASN1_Object):
    r"""
    /!\ ASN1_BIT_STRING values are bit strings like "011101".
    /!\ A zero-bit padded readable string is provided nonetheless,
    /!\ which is also output when __str__ is called.

    Three attributes are kept in sync by __setattr__:

    - ``val``: the bits, as a byte string of ASCII "0"/"1" characters;
    - ``val_readable``: the same bits packed into bytes, the last byte
      right-padded with zero bits;
    - ``unused_bits``: number of padding bits appended (0-7).  It cannot be
      assigned directly.
    """
    tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self, val, readable=False):
        # readable=True: ``val`` is already the packed byte form
        if not readable:
            self.val = val
        else:
            self.val_readable = val

    def __setattr__(self, name, value):
        str_value = None
        if isinstance(value, str):
            str_value = value
            value = raw(value)
        if name == "val_readable":
            if isinstance(value, bytes):
                # unpack each byte into its 8-character binary representation
                val = b"".join(binrepr(orb(x)).zfill(8).encode("utf8") for x in value)
            else:
                val = "<invalid val_readable>"
            super(ASN1_Object, self).__setattr__("val", val)
            super(ASN1_Object, self).__setattr__(name, value)
            super(ASN1_Object, self).__setattr__("unused_bits", 0)
        elif name == "val":
            if isinstance(value, bytes):
                # Validate the raw bytes themselves rather than a decoded
                # copy, so no byte can be altered or dropped by decoding
                # before the check.
                if value.replace(b"0", b"").replace(b"1", b""):
                    print("Invalid operation: 'val' is not a valid bit string.")
                    return
                if not str_value:
                    str_value = plain_str(value)
                # pad the bit count up to a whole number of bytes
                unused_bits = -len(value) % 8
                padded_value = str_value + ("0" * unused_bits)
                bytes_arr = zip(*[iter(padded_value)] * 8)
                val_readable = b"".join(chb(int("".join(x), 2)) for x in bytes_arr)
            else:
                val_readable = "<invalid val>"
                unused_bits = 0
            super(ASN1_Object, self).__setattr__("val_readable", val_readable)
            super(ASN1_Object, self).__setattr__(name, value)
            super(ASN1_Object, self).__setattr__("unused_bits", unused_bits)
        elif name == "unused_bits":
            print("Invalid operation: unused_bits rewriting is not supported.")
        else:
            super(ASN1_Object, self).__setattr__(name, value)

    def __repr__(self):
        label = self.__dict__.get("name", self.__class__.__name__)
        plural = "s" if self.unused_bits > 1 else ""
        if len(self.val) <= 16:
            return "<%s[%s] (%d unused bit%s)>" % (label, plain_str(self.val),
                                                   self.unused_bits, plural)
        # Long bit strings: show the packed form, truncated, instead of the
        # full run of "0"/"1" characters.
        s = self.val_readable
        if len(s) > 20:
            s = s[:10] + b"..." + s[-10:]
        return "<%s[%r] (%d unused bit%s)>" % (label, s, self.unused_bits, plural)

    def __str__(self):
        # val_readable is bytes; Python 3 requires __str__ to return text.
        return plain_str(self.val_readable)

    def __bytes__(self):
        return self.val_readable
340
class ASN1_STRING(ASN1_Object):
    """Universal tag 4 (OCTET STRING in X.680); base of the string types below."""
    tag = ASN1_Class_UNIVERSAL.STRING
343
class ASN1_NULL(ASN1_Object):
    """ASN.1 NULL value."""
    tag = ASN1_Class_UNIVERSAL.NULL

    def __repr__(self):
        # explicitly the generic ASN1_Object rendering
        return ASN1_Object.__repr__(self)
348
class ASN1_OID(ASN1_Object):
    """ASN.1 OBJECT IDENTIFIER.

    The constructor normalises ``val`` through ``conf.mib._oid`` and caches
    ``conf.mib._oidname`` of the result in ``oidname``, which repr shows.
    (Presumably ``val`` ends up in numeric dotted form; confirm against the
    MIB implementation.)
    """
    tag = ASN1_Class_UNIVERSAL.OID

    def __init__(self, val):
        oid = conf.mib._oid(plain_str(val))
        ASN1_Object.__init__(self, oid)
        self.oidname = conf.mib._oidname(oid)

    def __repr__(self):
        label = self.__dict__.get("name", self.__class__.__name__)
        return "<%s[%r]>" % (label, self.oidname)
357
class ASN1_ENUMERATED(ASN1_INTEGER):
    """ASN.1 ENUMERATED; behaves like ASN1_INTEGER."""
    tag = ASN1_Class_UNIVERSAL.ENUMERATED


class ASN1_UTF8_STRING(ASN1_STRING):
    """ASN.1 UTF8String."""
    tag = ASN1_Class_UNIVERSAL.UTF8_STRING


class ASN1_NUMERIC_STRING(ASN1_STRING):
    """ASN.1 NumericString."""
    tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING


class ASN1_PRINTABLE_STRING(ASN1_STRING):
    """ASN.1 PrintableString."""
    tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING


class ASN1_T61_STRING(ASN1_STRING):
    """ASN.1 T61String (aka TeletexString)."""
    tag = ASN1_Class_UNIVERSAL.T61_STRING


class ASN1_VIDEOTEX_STRING(ASN1_STRING):
    """ASN.1 VideotexString."""
    tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING


class ASN1_IA5_STRING(ASN1_STRING):
    """ASN.1 IA5String."""
    tag = ASN1_Class_UNIVERSAL.IA5_STRING
378
class ASN1_UTC_TIME(ASN1_STRING):
    """ASN.1 UTCTime.

    Assigning ``val`` also computes ``pretty_time``, the human-readable form
    shown by repr.  Only the ``YYMMDDHHMMSSZ`` form is prettified.  Anything
    else, including values of that shape that are not real dates, is
    rendered as ``"<val> [invalid utc_time]"``.  ``pretty_time`` cannot be
    assigned directly.
    """
    tag = ASN1_Class_UNIVERSAL.UTC_TIME

    def __init__(self, val):
        super(ASN1_UTC_TIME, self).__init__(val)

    def __setattr__(self, name, value):
        if isinstance(value, bytes):
            value = plain_str(value)
        if name == "val":
            pretty_time = None
            if (isinstance(value, str) and
                    len(value) == 13 and value[-1] == "Z"):
                try:
                    dt = datetime.strptime(value[:-1], "%y%m%d%H%M%S")
                    # Python's %y maps 00-68 to 20YY, but RFC 5280 reads
                    # YY >= 50 as 19YY: shift 2050-2068 back a century.
                    if dt.year >= 2050:
                        dt = dt.replace(year=dt.year - 100)
                    pretty_time = dt.strftime("%b %d %H:%M:%S %Y GMT")
                except ValueError:
                    # right shape but not a valid date/time: report it as
                    # invalid instead of failing the assignment
                    pretty_time = None
            if pretty_time is None:
                pretty_time = "%s [invalid utc_time]" % value
            super(ASN1_UTC_TIME, self).__setattr__("pretty_time", pretty_time)
            super(ASN1_UTC_TIME, self).__setattr__(name, value)
        elif name == "pretty_time":
            print("Invalid operation: pretty_time rewriting is not supported.")
        else:
            super(ASN1_UTC_TIME, self).__setattr__(name, value)

    def __repr__(self):
        return "%s %s" % (self.pretty_time, ASN1_STRING.__repr__(self))
402
class ASN1_GENERALIZED_TIME(ASN1_STRING):
    """ASN.1 GeneralizedTime.

    Assigning ``val`` also computes ``pretty_time``, the human-readable form
    shown by repr.  Only the ``YYYYMMDDHHMMSSZ`` form is prettified.
    Anything else, including values of that shape that are not real dates,
    is rendered as ``"<val> [invalid generalized_time]"``.  ``pretty_time``
    cannot be assigned directly.
    """
    tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME

    def __init__(self, val):
        super(ASN1_GENERALIZED_TIME, self).__init__(val)

    def __setattr__(self, name, value):
        if isinstance(value, bytes):
            value = plain_str(value)
        if name == "val":
            pretty_time = None
            if (isinstance(value, str) and
                    len(value) == 15 and value[-1] == "Z"):
                try:
                    dt = datetime.strptime(value[:-1], "%Y%m%d%H%M%S")
                    # strftime is inside the try as well: on Python 2 it
                    # raises ValueError for years before 1900.
                    pretty_time = dt.strftime("%b %d %H:%M:%S %Y GMT")
                except ValueError:
                    # right shape but not a valid date/time: report it as
                    # invalid instead of failing the assignment
                    pretty_time = None
            if pretty_time is None:
                pretty_time = "%s [invalid generalized_time]" % value
            super(ASN1_GENERALIZED_TIME, self).__setattr__("pretty_time", pretty_time)
            super(ASN1_GENERALIZED_TIME, self).__setattr__(name, value)
        elif name == "pretty_time":
            print("Invalid operation: pretty_time rewriting is not supported.")
        else:
            super(ASN1_GENERALIZED_TIME, self).__setattr__(name, value)

    def __repr__(self):
        return "%s %s" % (self.pretty_time, ASN1_STRING.__repr__(self))
426
class ASN1_ISO646_STRING(ASN1_STRING):
    """ASN.1 ISO646String (aka VisibleString)."""
    tag = ASN1_Class_UNIVERSAL.ISO646_STRING


class ASN1_UNIVERSAL_STRING(ASN1_STRING):
    """ASN.1 UniversalString."""
    tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING


class ASN1_BMP_STRING(ASN1_STRING):
    """ASN.1 BMPString."""
    tag = ASN1_Class_UNIVERSAL.BMP_STRING
435
class ASN1_SEQUENCE(ASN1_Object):
    """ASN.1 SEQUENCE; ``val`` is an iterable of ASN.1 objects."""
    tag = ASN1_Class_UNIVERSAL.SEQUENCE

    def strshow(self, lvl=0):
        """Return a header line followed by each element, indented one level deeper."""
        indent = "  " * lvl
        parts = [indent + "# %s:" % self.__class__.__name__ + "\n"]
        parts.extend(item.strshow(lvl=lvl + 1) for item in self.val)
        return "".join(parts)
443
class ASN1_SET(ASN1_SEQUENCE):
    """ASN.1 SET; reuses ASN1_SEQUENCE's strshow."""
    tag = ASN1_Class_UNIVERSAL.SET


class ASN1_IPADDRESS(ASN1_STRING):
    """IP address carried as a string, with an application-specific tag."""
    tag = ASN1_Class_UNIVERSAL.IPADDRESS


class ASN1_COUNTER32(ASN1_INTEGER):
    """Application-specific COUNTER32 integer."""
    tag = ASN1_Class_UNIVERSAL.COUNTER32


class ASN1_GAUGE32(ASN1_INTEGER):
    """Application-specific GAUGE32 integer."""
    tag = ASN1_Class_UNIVERSAL.GAUGE32


class ASN1_TIME_TICKS(ASN1_INTEGER):
    """Application-specific TIME_TICKS integer."""
    tag = ASN1_Class_UNIVERSAL.TIME_TICKS
458
459
# Codec used by ASN1_Object.__str__ and __bytes__ when no codec is given.
conf.ASN1_default_codec = ASN1_Codecs.BER
461