1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## Enhanced by Maxence Tury <maxence.tury@ssi.gouv.fr>
5## This program is published under a GPLv2 license
6
7"""
8Classes that implement ASN.1 data structures.
9"""
10
11from __future__ import absolute_import
12from scapy.asn1.asn1 import *
13from scapy.asn1.ber import *
14from scapy.asn1.mib import *
15from scapy.volatile import *
16from scapy.compat import *
17from scapy.base_classes import BasePacket
18from scapy.utils import binrepr
19from scapy import packet
20from functools import reduce
21import scapy.modules.six as six
22from scapy.modules.six.moves import range
23
class ASN1F_badsequence(Exception):
    """Signal that some bytes cannot be dissected as the expected sequence.

    ASN1F_field.extract_packet catches it and falls back to a Raw layer,
    ASN1F_SEQUENCE.m2i stops dissecting its remaining sub-fields on it, and
    ASN1F_optional treats it as the wrapped field being absent.
    """
26
class ASN1F_element(object):
    """Common base class of every ASN.1 field class of this module.

    It defines no behaviour of its own.
    """
29
30
31##########################
32#### Basic ASN1 Field ####
33##########################
34
class ASN1F_field(ASN1F_element):
    """
    Base class of the ASN.1 fields.

    A field has a universal tag (ASN1_tag), optionally an implicit or an
    explicit tag, and converts between the value stored on a packet and its
    BER encoding (i2m for building, m2i for dissecting).
    """
    holds_packets = 0
    islist = 0
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY
    context = ASN1_Class_UNIVERSAL

    def __init__(self, name, default, context=None,
                 implicit_tag=None, explicit_tag=None,
                 flexible_tag=False):
        """
        :param name: name of the packet attribute holding the value
        :param default: default value; wrapped with ASN1_tag.asn1_object()
            unless it is None or already an ASN1_NULL
        :param context: context passed to the codec when decoding
        :param implicit_tag: implicit tag of the field, if any
        :param explicit_tag: explicit tag of the field, if any
        :param flexible_tag: tolerate unexpected tags when decoding (see m2i)
        :raises ASN1_Error: if both implicit_tag and explicit_tag are given
        """
        self.context = context
        self.name = name
        if default is None:
            self.default = None
        elif isinstance(default, ASN1_NULL):
            self.default = default
        else:
            self.default = self.ASN1_tag.asn1_object(default)
        self.flexible_tag = flexible_tag
        if (implicit_tag is not None) and (explicit_tag is not None):
            err_msg = "field cannot be both implicitly and explicitly tagged"
            raise ASN1_Error(err_msg)
        self.implicit_tag = implicit_tag
        self.explicit_tag = explicit_tag
        # network_tag gets useful for ASN1F_CHOICE
        self.network_tag = implicit_tag or explicit_tag or self.ASN1_tag

    def i2repr(self, pkt, x):
        return repr(x)

    def i2h(self, pkt, x):
        return x

    def any2i(self, pkt, x):
        return x

    def m2i(self, pkt, s):
        """
        The good thing about safedec is that it may still decode ASN1
        even if there is a mismatch between the expected tag (self.ASN1_tag)
        and the actual tag; the decoded ASN1 object will simply be put
        into an ASN1_BADTAG object. However, safedec prevents the raising of
        exceptions needed for ASN1F_optional processing.
        Thus we use 'flexible_tag', which should be False with ASN1F_optional.

        Regarding other fields, we might need to know whether decoding went
        as expected or not. Noticeably, input methods from cert.py expect
        certain exceptions to be raised. Hence default flexible_tag is False.

        Returns a (value, remainder) tuple.
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag)
        if diff_tag is not None:
            # this implies that flexible_tag was True: remember the tag
            # actually seen so that i2m re-encodes it
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        if self.flexible_tag:
            return codec.safedec(s, context=self.context)
        else:
            return codec.dec(s, context=self.context)

    def i2m(self, pkt, x):
        """
        Encode x, then apply the field's implicit/explicit tagging.

        None encodes to an empty string. An ASN1_Object must carry this
        field's tag unless the field is ANY or the object is RAW/ERROR.

        :raises ASN1_Error: on an ASN1_Object with an unexpected tag
        """
        if x is None:
            return b""
        if isinstance(x, ASN1_Object):
            if ( self.ASN1_tag == ASN1_Class_UNIVERSAL.ANY
                 or x.tag == ASN1_Class_UNIVERSAL.RAW
                 or x.tag == ASN1_Class_UNIVERSAL.ERROR
                 or self.ASN1_tag == x.tag ):
                s = x.enc(pkt.ASN1_codec)
            else:
                raise ASN1_Error("Encoding Error: got %r instead of an %r for field [%s]" % (x, self.ASN1_tag, self.name))
        else:
            s = self.ASN1_tag.get_codec(pkt.ASN1_codec).enc(x)
        return BER_tagging_enc(s, implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def extract_packet(self, cls, s):
        """
        Dissect s as an instance of cls.

        Returns a (packet, remainder) tuple, where the remainder is the
        trailing data that cls did not consume (detached from the packet).
        If cls raises ASN1F_badsequence, the whole input is kept in a Raw
        packet and there is no remainder. Returns (None, s) for empty input.
        """
        if len(s) == 0:
            return None, s
        try:
            c = cls(s)
        except ASN1F_badsequence:
            c = packet.Raw(s)
        cpad = c.getlayer(packet.Raw)
        if cpad is None:
            return c, b""
        if cpad.underlayer is None:
            # c is the Raw fallback itself: it already holds all of s and
            # there is no underlayer to detach it from.
            return c, b""
        remain = cpad.load
        del cpad.underlayer.payload
        return c, remain

    def build(self, pkt):
        return self.i2m(pkt, getattr(pkt, self.name))

    def dissect(self, pkt, s):
        """Decode this field's value from s, store it on pkt, return the rest."""
        v, s = self.m2i(pkt, s)
        self.set_val(pkt, v)
        return s

    def do_copy(self, x):
        """
        Return a copy of the value x.

        Lists are checked first: on Python 3, list has a copy() method that
        only makes a shallow copy, which would leave the packets of the list
        shared between the original and the copy.
        """
        if isinstance(x, list):
            return [e.copy() if isinstance(e, BasePacket) else e for e in x]
        if hasattr(x, "copy"):
            return x.copy()
        return x

    def set_val(self, pkt, val):
        setattr(pkt, self.name, val)

    def is_empty(self, pkt):
        return getattr(pkt, self.name) is None

    def get_fields_list(self):
        return [self]

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return repr(self)

    def randval(self):
        return RandInt()
154
155
156############################
157#### Simple ASN1 Fields ####
158############################
159
class ASN1F_BOOLEAN(ASN1F_field):
    """ASN.1 BOOLEAN field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.BOOLEAN

    def randval(self):
        """Return a volatile value picking either True or False."""
        outcomes = (True, False)
        return RandChoice(*outcomes)
164
class ASN1F_INTEGER(ASN1F_field):
    """ASN.1 INTEGER field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.INTEGER

    def randval(self):
        """Return a volatile random number between -2**64 and 2**64 - 1."""
        limit = 1 << 64
        return RandNum(-limit, limit - 1)
169
class ASN1F_enum_INTEGER(ASN1F_INTEGER):
    """
    INTEGER field whose values can be given and displayed by name.

    enum is either a list (the index is the integer value) or a mapping.
    A mapping may go from integers to names or from names to integers; the
    direction is detected from the type of its keys. In both cases self.i2s
    maps integer -> name and self.s2i maps name -> integer.
    """
    def __init__(self, name, default, enum, context=None,
                 implicit_tag=None, explicit_tag=None):
        ASN1F_INTEGER.__init__(self, name, default, context=context,
                               implicit_tag=implicit_tag,
                               explicit_tag=explicit_tag)
        i2s = self.i2s = {}
        s2i = self.s2i = {}
        if isinstance(enum, list):
            keys = range(len(enum))
        else:
            keys = list(enum)
        # Keys that are names mean enum maps name -> int: swap the targets so
        # that the writes below still fill self.i2s and self.s2i correctly.
        if any(isinstance(x, six.string_types) for x in keys):
            i2s, s2i = s2i, i2s
        for k in keys:
            i2s[k] = enum[k]
            s2i[enum[k]] = k

    def i2m(self, pkt, s):
        """
        Translate a name into its integer value, then encode it.

        six.string_types (rather than str) also accepts unicode names on
        Python 2, matching the key detection done in __init__. An unknown
        name becomes None, which the base i2m encodes as an empty string.
        """
        if isinstance(s, six.string_types):
            s = self.s2i.get(s)
        return super(ASN1F_enum_INTEGER, self).i2m(pkt, s)

    def i2repr(self, pkt, x):
        """Prefix the representation with the value's name, when known."""
        if isinstance(x, ASN1_INTEGER):
            r = self.i2s.get(x.val)
            if r:
                return "'%s' %s" % (r, repr(x))
        return repr(x)
197
class ASN1F_BIT_STRING(ASN1F_field):
    """
    ASN.1 BIT STRING field.

    When default_readable is True, a non-None default is given as raw bytes.
    Each byte is turned into its binrepr() output zero-filled to 8 characters
    before the default is stored.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self, name, default, default_readable=True, context=None,
                 implicit_tag=None, explicit_tag=None):
        if default_readable and default is not None:
            bit_chunks = [binrepr(orb(octet)).zfill(8).encode("utf8")
                          for octet in default]
            default = b"".join(bit_chunks)
        ASN1F_field.__init__(self, name, default, context=context,
                             implicit_tag=implicit_tag,
                             explicit_tag=explicit_tag)

    def randval(self):
        """Return a volatile random string whose length is between 0 and 1000."""
        length = RandNum(0, 1000)
        return RandString(length)
209
class ASN1F_STRING(ASN1F_field):
    """ASN.1 string field using the universal STRING tag.

    It is also the base class of the specialised string fields below.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.STRING

    def randval(self):
        """Return a volatile random string whose length is between 0 and 1000."""
        length = RandNum(0, 1000)
        return RandString(length)
214
class ASN1F_NULL(ASN1F_INTEGER):
    """ASN.1 NULL field (built on ASN1F_INTEGER, whose methods it inherits)."""
    ASN1_tag = ASN1_Class_UNIVERSAL.NULL
217
class ASN1F_OID(ASN1F_field):
    """ASN.1 OBJECT IDENTIFIER field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.OID

    def randval(self):
        """Return a volatile random OID."""
        random_oid = RandOID()
        return random_oid
222
class ASN1F_ENUMERATED(ASN1F_enum_INTEGER):
    """Named-value integer field encoded with the universal ENUMERATED tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.ENUMERATED
225
class ASN1F_UTF8_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 UTF8String."""
    ASN1_tag = ASN1_Class_UNIVERSAL.UTF8_STRING

class ASN1F_NUMERIC_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 NumericString."""
    ASN1_tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING

class ASN1F_PRINTABLE_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 PrintableString."""
    ASN1_tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING

class ASN1F_T61_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 T61String."""
    ASN1_tag = ASN1_Class_UNIVERSAL.T61_STRING

class ASN1F_VIDEOTEX_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 VideotexString."""
    ASN1_tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING

class ASN1F_IA5_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 IA5String."""
    ASN1_tag = ASN1_Class_UNIVERSAL.IA5_STRING
243
class ASN1F_UTC_TIME(ASN1F_STRING):
    """String field tagged as ASN.1 UTCTime.

    Note that randval() yields a GeneralizedTime volatile value.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.UTC_TIME

    def randval(self):
        random_time = GeneralizedTime()
        return random_time

class ASN1F_GENERALIZED_TIME(ASN1F_STRING):
    """String field tagged as ASN.1 GeneralizedTime."""
    ASN1_tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME

    def randval(self):
        random_time = GeneralizedTime()
        return random_time
253
class ASN1F_ISO646_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 ISO646String."""
    ASN1_tag = ASN1_Class_UNIVERSAL.ISO646_STRING

class ASN1F_UNIVERSAL_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 UniversalString."""
    ASN1_tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING

class ASN1F_BMP_STRING(ASN1F_STRING):
    """String field tagged as ASN.1 BMPString."""
    ASN1_tag = ASN1_Class_UNIVERSAL.BMP_STRING
262
class ASN1F_SEQUENCE(ASN1F_field):
    """
    ASN.1 SEQUENCE made of the sub-fields passed positionally.

    The sequence stores no value of its own on the packet. Dissection sets
    every sub-field directly, and building concatenates the sub-field
    encodings before wrapping them in the SEQUENCE header.

    Here is how you could decode a SEQUENCE with an unknown, private
    high-tag prefix::

        class PrivSeq(ASN1_Packet):
            ASN1_codec = ASN1_Codecs.BER
            ASN1_root = ASN1F_SEQUENCE(
                              <asn1 field #0>,
                              ...
                              <asn1 field #N>,
                              explicit_tag=0,
                              flexible_tag=True)

    Because we use flexible_tag, the value of the explicit_tag does not matter.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    holds_packets = 1

    def __init__(self, *seq, **kwargs):
        sub_defaults = [sub.default for sub in seq]
        # Every option missing from kwargs is set to None (flexible_tag too).
        for option in ("context", "implicit_tag",
                       "explicit_tag", "flexible_tag"):
            setattr(self, option, kwargs.get(option))
        ASN1F_field.__init__(self, "dummy_seq_name", sub_defaults,
                             context=self.context,
                             implicit_tag=self.implicit_tag,
                             explicit_tag=self.explicit_tag,
                             flexible_tag=self.flexible_tag)
        self.seq = seq
        self.islist = len(seq) > 1

    def __repr__(self):
        return "<%s%r>" % (self.__class__.__name__, self.seq)

    def is_empty(self, pkt):
        """A sequence is empty when all of its sub-fields are."""
        return all(sub.is_empty(pkt) for sub in self.seq)

    def get_fields_list(self):
        """Return the flattened list of the sub-fields' own fields."""
        flattened = []
        for sub in self.seq:
            flattened = flattened + sub.get_fields_list()
        return flattened

    def m2i(self, pkt, s):
        """
        ASN1F_SEQUENCE behaves transparently, with nested ASN1_objects being
        dissected one by one. Because we use obj.dissect (see loop below)
        instead of obj.m2i (as we trust dissect to do the appropriate set_vals)
        we do not directly retrieve the list of nested objects.
        Thus m2i returns an empty list (along with the proper remainder).
        It is discarded by dissect() and should not be missed elsewhere.
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag)
        if diff_tag is not None:
            # a differing tag is only reported in flexible mode: keep it
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        _, content, remain = codec.check_type_check_len(s)
        if len(content) == 0:
            # empty SEQUENCE: every sub-field is absent
            for sub in self.seq:
                sub.set_val(pkt, None)
            return [], remain
        for sub in self.seq:
            try:
                content = sub.dissect(pkt, content)
            except ASN1F_badsequence:
                break
        if len(content) > 0:
            raise BER_Decoding_Error("unexpected remainder", remaining=content)
        return [], remain

    def dissect(self, pkt, s):
        """Dissect the sub-fields into pkt and return the trailing data."""
        remain = self.m2i(pkt, s)[1]
        return remain

    def build(self, pkt):
        """Concatenate the sub-field encodings and wrap them in a SEQUENCE."""
        content = b"".join(sub.build(pkt) for sub in self.seq)
        return self.i2m(pkt, content)
339
class ASN1F_SET(ASN1F_SEQUENCE):
    """ASN.1 SET: behaves like ASN1F_SEQUENCE but uses the universal SET tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.SET
342
class ASN1F_SEQUENCE_OF(ASN1F_field):
    """
    ASN.1 SEQUENCE OF: a list of packets of class cls.

    The packet attribute holds a list of cls instances, None, or a RAW
    ASN1 object whose encoding is emitted as-is when building.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    holds_packets = 1
    islist = 1

    def __init__(self, name, default, cls, context=None,
                 implicit_tag=None, explicit_tag=None):
        self.cls = cls
        ASN1F_field.__init__(self, name, None, context=context,
                             implicit_tag=implicit_tag,
                             explicit_tag=explicit_tag)
        # the default is kept as given (not wrapped in an ASN1 object)
        self.default = default

    def is_empty(self, pkt):
        return ASN1F_field.is_empty(self, pkt)

    def m2i(self, pkt, s):
        """
        Decode the SEQUENCE OF header, then dissect its content as cls packets.

        Returns (list of packets, data following the SEQUENCE OF).

        :raises BER_Decoding_Error: if dissecting an element consumes no
            data, which would otherwise loop forever on malformed input
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        _, s, remain = codec.check_type_check_len(s)
        lst = []
        while s:
            c, rest = self.extract_packet(self.cls, s)
            if len(rest) >= len(s):
                raise BER_Decoding_Error(
                    "ASN1F_SEQUENCE_OF: element dissection consumed no data",
                    remaining=s)
            lst.append(c)
            s = rest
        return lst, remain

    def build(self, pkt):
        val = getattr(pkt, self.name)
        if isinstance(val, ASN1_Object) and val.tag==ASN1_Class_UNIVERSAL.RAW:
            s = val
        elif val is None:
            s = b""
        else:
            s = b"".join(raw(i) for i in val)
        return self.i2m(pkt, s)

    def randval(self):
        return packet.fuzz(self.cls())

    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.name)
388
class ASN1F_SET_OF(ASN1F_SEQUENCE_OF):
    """ASN.1 SET OF: like ASN1F_SEQUENCE_OF but with the universal SET tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.SET
391
class ASN1F_IPADDRESS(ASN1F_STRING):
    """String field tagged with the universal IPADDRESS tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.IPADDRESS

class ASN1F_TIME_TICKS(ASN1F_INTEGER):
    """Integer field tagged with the universal TIME_TICKS tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.TIME_TICKS
397
398
399#############################
400#### Complex ASN1 Fields ####
401#############################
402
class ASN1F_optional(ASN1F_element):
    """
    Wrapper making an ASN.1 field OPTIONAL.

    ASN1_Error, ASN1F_badsequence and BER_Decoding_Error raised by the
    wrapped field while decoding are taken to mean the field is absent. In
    that case dissect() sets the value to None and leaves the input
    unconsumed. Attributes not defined here are looked up on the wrapped
    field.
    """
    def __init__(self, field):
        # Tag mismatches must raise so that absence can be detected, hence
        # the wrapped field must not decode in flexible mode.
        field.flexible_tag = False
        self._field = field

    def __getattr__(self, attr):
        # __getattr__ only runs for attributes missing from the instance.
        # If _field itself is missing (e.g. an instance created without
        # __init__ by copy or pickle), delegating would recurse forever.
        if attr == "_field":
            raise AttributeError(attr)
        return getattr(self._field, attr)

    def m2i(self, pkt, s):
        try:
            return self._field.m2i(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            # ASN1_Error may be raised by ASN1F_CHOICE
            return None, s

    def dissect(self, pkt, s):
        try:
            return self._field.dissect(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            self._field.set_val(pkt, None)
            return s

    def build(self, pkt):
        # an absent optional field contributes nothing to the encoding
        if self._field.is_empty(pkt):
            return b""
        return self._field.build(pkt)

    def any2i(self, pkt, x):
        return self._field.any2i(pkt, x)

    def i2repr(self, pkt, x):
        return self._field.i2repr(pkt, x)
429
class ASN1F_CHOICE(ASN1F_field):
    """
    Multiple types are allowed: ASN1_Packet, ASN1F_field and ASN1F_PACKET().
    See layers/x509.py for examples.
    Other ASN1F_field instances than ASN1F_PACKET instances must not be used.

    Alternatives are indexed by their network tag in self.choices. Decoding
    looks up the identifier returned by BER_id_dec (after removing the
    CHOICE's own explicit tag, if any) and delegates to the matching
    alternative. The CHOICE may only be explicitly tagged.
    """
    holds_packets = 1
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY
    def __init__(self, name, default, *args, **kwargs):
        """
        :param name: name of the packet attribute holding the chosen value
        :param default: default value, stored as given
        :param args: the alternatives (ASN1_Packet classes, ASN1F_field
            classes or ASN1F_PACKET instances)
        :param kwargs: "context" and "explicit_tag" are honoured
        :raises ASN1_Error: if implicit_tag is given, or if an alternative
            has no tag
        """
        # implicit tagging is rejected; only an explicit tag may wrap a CHOICE
        if "implicit_tag" in kwargs:
            err_msg = "ASN1F_CHOICE has been called with an implicit_tag"
            raise ASN1_Error(err_msg)
        self.implicit_tag = None
        for kwarg in ["context", "explicit_tag"]:
            if kwarg in kwargs:
                setattr(self, kwarg, kwargs[kwarg])
            else:
                setattr(self, kwarg, None)
        # flexible_tag is left at ASN1F_field's default (False)
        ASN1F_field.__init__(self, name, None, context=self.context,
                             explicit_tag=self.explicit_tag)
        self.default = default
        # not updated by the methods of this class
        self.current_choice = None
        # network tag -> alternative (packet class, field class, or
        # ASN1F_PACKET instance)
        self.choices = {}
        # hash(packet class of an ASN1F_PACKET alternative) ->
        # (implicit_tag, explicit_tag) of that alternative, reapplied by i2m
        self.pktchoices = {}
        for p in args:
            if hasattr(p, "ASN1_root"):     # should be ASN1_Packet
                if hasattr(p.ASN1_root, "choices"):
                    # a packet whose root is itself a CHOICE: its
                    # alternatives are merged into this one
                    for k,v in six.iteritems(p.ASN1_root.choices):
                        self.choices[k] = v         # ASN1F_CHOICE recursion
                else:
                    self.choices[p.ASN1_root.network_tag] = p
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):         # should be ASN1F_field class
                    self.choices[p.ASN1_tag] = p
                else:                       # should be ASN1F_PACKET instance
                    self.choices[p.network_tag] = p
                    self.pktchoices[hash(p.cls)] = (p.implicit_tag, p.explicit_tag)
            else:
                raise ASN1_Error("ASN1F_CHOICE: no tag found for one field")
    def m2i(self, pkt, s):
        """
        First we have to retrieve the appropriate choice.
        Then we extract the field/packet, according to this choice.

        :raises ASN1_Error: on empty input, or when the tag matches no
            alternative and flexible_tag is not set
        """
        if len(s) == 0:
            raise ASN1_Error("ASN1F_CHOICE: got empty string")
        # remove the CHOICE's own explicit tag header, if any
        _,s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                              explicit_tag=self.explicit_tag)
        # peek at the next identifier; s is passed on whole, so the chosen
        # alternative decodes starting from its own tag
        tag,_ = BER_id_dec(s)
        if tag not in self.choices:
            if self.flexible_tag:
                # fall back to a generic field accepting any tag
                choice = ASN1F_field
            else:
                raise ASN1_Error("ASN1F_CHOICE: unexpected field")
        else:
            choice = self.choices[tag]
        if hasattr(choice, "ASN1_root"):
            # we don't want to import ASN1_Packet in this module...
            return self.extract_packet(choice, s)
        elif isinstance(choice, type):
            #XXX find a way not to instantiate the ASN1F_field
            return choice(self.name, b"").m2i(pkt, s)
        else:
            #XXX check properly if this is an ASN1F_PACKET
            return choice.m2i(pkt, s)
    def i2m(self, pkt, x):
        """Serialize x, then apply the alternative's and the CHOICE's tags."""
        if x is None:
            s = b""
        else:
            s = raw(x)
            # x's class was registered by an ASN1F_PACKET alternative:
            # re-apply that alternative's implicit/explicit tagging
            if hash(type(x)) in self.pktchoices:
                imp, exp = self.pktchoices[hash(type(x))]
                s = BER_tagging_enc(s, implicit_tag=imp,
                                    explicit_tag=exp)
        # finally apply the CHOICE's own explicit tag, if any
        return BER_tagging_enc(s, explicit_tag=self.explicit_tag)
    def randval(self):
        """Return a volatile value picking a random value of any alternative."""
        randchoices = []
        for p in six.itervalues(self.choices):
            if hasattr(p, "ASN1_root"):   # should be ASN1_Packet class
                randchoices.append(packet.fuzz(p()))
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):       # should be (basic) ASN1F_field class
                    randchoices.append(p("dummy", None).randval())
                else:                     # should be ASN1F_PACKET instance
                    randchoices.append(p.randval())
        return RandChoice(*randchoices)
516
class ASN1F_PACKET(ASN1F_field):
    """Field holding a whole ASN1_Packet of class cls."""
    holds_packets = 1

    def __init__(self, name, default, cls, context=None,
                 implicit_tag=None, explicit_tag=None):
        self.cls = cls
        ASN1F_field.__init__(self, name, None, context=context,
                             implicit_tag=implicit_tag,
                             explicit_tag=explicit_tag)
        root_is_sequence = (cls.ASN1_root.ASN1_tag ==
                            ASN1_Class_UNIVERSAL.SEQUENCE)
        if root_is_sequence and implicit_tag is None and explicit_tag is None:
            # 16 | 0x20 == 0x30: the SEQUENCE identifier octet (tag number 16
            # with the BER "constructed" bit), which is the form ASN1F_CHOICE
            # compares against the identifier returned by BER_id_dec
            self.network_tag = 16 | 0x20
        # the default is kept as given (not wrapped in an ASN1 object)
        self.default = default

    def m2i(self, pkt, s):
        """
        Decode the optional tagging, then dissect the data as a cls packet.

        Returns the (packet, remainder) tuple produced by extract_packet.
        """
        root_tag = self.cls.ASN1_root.ASN1_tag
        diff_tag, s = BER_tagging_dec(s, hidden_tag=root_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag)
        if diff_tag is not None:
            # flexible decoding saw another tag: keep it for re-encoding
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        return self.extract_packet(self.cls, s)

    def i2m(self, pkt, x):
        """Serialize the packet x (None -> empty) and apply the tagging."""
        payload = b"" if x is None else raw(x)
        return BER_tagging_enc(payload, implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def randval(self):
        return packet.fuzz(self.cls())
549
class ASN1F_BIT_STRING_ENCAPS(ASN1F_BIT_STRING):
    """
    We may emulate simple string encapsulation with explicit_tag=0x04,
    but we need a specific class for bit strings because of unused bits, etc.

    The BIT STRING content is dissected as a packet of class cls. It must be
    a whole number of bytes and must be fully consumed by cls.
    """
    holds_packets = 1

    def __init__(self, name, default, cls, context=None,
                 implicit_tag=None, explicit_tag=None):
        self.cls = cls
        ASN1F_BIT_STRING.__init__(self, name, None, context=context,
                                  implicit_tag=implicit_tag,
                                  explicit_tag=explicit_tag)
        # the default is kept as given (not wrapped in an ASN1 object)
        self.default = default

    def m2i(self, pkt, s):
        """
        Decode the BIT STRING, then dissect its bytes as a cls packet.

        :raises BER_Decoding_Error: if the bit count is not a multiple of 8,
            or if cls leaves trailing data
        """
        bit_string, remain = ASN1F_BIT_STRING.m2i(self, pkt, s)
        if len(bit_string.val) % 8:
            raise BER_Decoding_Error("wrong bit string", remaining=s)
        inner, leftover = self.extract_packet(self.cls,
                                              bit_string.val_readable)
        if len(leftover) > 0:
            raise BER_Decoding_Error("unexpected remainder",
                                     remaining=leftover)
        return inner, remain

    def i2m(self, pkt, x):
        """Serialize x (None -> empty) as a bit string of '0'/'1' bytes."""
        packed = b"" if x is None else raw(x)
        bits = b"".join(binrepr(orb(octet)).zfill(8).encode("utf8")
                        for octet in packed)
        return ASN1F_BIT_STRING.i2m(self, pkt, bits)
578
class ASN1F_FLAGS(ASN1F_BIT_STRING):
    """
    BIT STRING field whose set bits are displayed as named flags.

    mapping[i] names the flag carried by the i-th character of the bit
    string; bits beyond len(mapping) are ignored.
    """
    def __init__(self, name, default, mapping, context=None,
                 implicit_tag=None, explicit_tag=None):
        self.mapping = mapping
        ASN1F_BIT_STRING.__init__(self, name, default,
                                  default_readable=False,
                                  context=context,
                                  implicit_tag=implicit_tag,
                                  explicit_tag=explicit_tag)

    def get_flags(self, pkt):
        """Return the names of the flags set in pkt's value of this field."""
        bits = getattr(pkt, self.name).val
        # The BIT STRING fields of this module build their bit strings as
        # bytes (b"0101..."). On Python 3, iterating bytes yields ints, which
        # never equal '1', so decode to text before comparing characters.
        if isinstance(bits, bytes):
            bits = bits.decode("ascii", "replace")
        return [self.mapping[i] for i, bit in enumerate(bits)
                if bit == '1' and i < len(self.mapping)]

    def i2repr(self, pkt, x):
        if x is not None:
            pretty_s = ", ".join(self.get_flags(pkt))
            return pretty_s + " " + repr(x)
        return repr(x)
600