1## This file is part of Scapy
2## Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
3##               2015, 2016, 2017 Maxence Tury
4## This program is published under a GPLv2 license
5
6"""
7TLS handshake fields & logic.
8
9This module covers the handshake TLS subprotocol, except for the key exchange
10mechanisms which are addressed with keyexchange.py.
11"""
12
13from __future__ import absolute_import
14import math
15
16from scapy.error import log_runtime, warning
17from scapy.fields import *
18from scapy.compat import *
19from scapy.packet import Packet, Raw, Padding
20from scapy.utils import repr_hex
21from scapy.layers.x509 import OCSP_Response
22from scapy.layers.tls.cert import Cert, PrivKey, PubKey
23from scapy.layers.tls.basefields import (_tls_version, _TLSVersionField,
24                                         _TLSClientVersionField)
25from scapy.layers.tls.extensions import (_ExtensionsLenField, _ExtensionsField,
26                                         _cert_status_type, TLS_Ext_SupportedVersions)
27from scapy.layers.tls.keyexchange import (_TLSSignature, _TLSServerParamsField,
28                                          _TLSSignatureField, ServerRSAParams,
29                                          SigAndHashAlgsField, _tls_hash_sig,
30                                          SigAndHashAlgsLenField)
31from scapy.layers.tls.keyexchange_tls13 import TicketField
32from scapy.layers.tls.session import (_GenericTLSSessionInheritance,
33                                      readConnState, writeConnState)
34from scapy.layers.tls.crypto.compression import (_tls_compression_algs,
35                                                 _tls_compression_algs_cls,
36                                                 Comp_NULL, _GenericComp,
37                                                 _GenericCompMetaclass)
38from scapy.layers.tls.crypto.suites import (_tls_cipher_suites,
39                                            _tls_cipher_suites_cls,
40                                            _GenericCipherSuite,
41                                            _GenericCipherSuiteMetaclass)
42
43
44###############################################################################
45### Generic TLS Handshake message                                           ###
46###############################################################################
47
48_tls_handshake_type = { 0: "hello_request",         1: "client_hello",
49                        2: "server_hello",          3: "hello_verify_request",
50                        4: "session_ticket",        6: "hello_retry_request",
51                        8: "encrypted_extensions",  11: "certificate",
52                        12: "server_key_exchange",  13: "certificate_request",
53                        14: "server_hello_done",    15: "certificate_verify",
54                        16: "client_key_exchange",  20: "finished",
55                        21: "certificate_url",      22: "certificate_status",
56                        23: "supplemental_data",    24: "key_update" }
57
58
class _TLSHandshake(_GenericTLSSessionInheritance):
    """
    Base class for every TLS Handshake message.

    Subclasses inherit the msglen-filling post_build() and the session
    bookkeeping of tls_session_update(). This class is also the fallback
    used to dissect Handshake messages of unknown type.
    """
    name = "TLS Handshake Generic message"
    fields_desc = [ByteEnumField("msgtype", None, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   StrLenField("msg", "",
                               length_from=lambda pkt: pkt.msglen)]

    def post_build(self, p, pay):
        # When msglen was left unset, rewrite the 4-byte header so that the
        # 3-byte length covers everything after the header; the msgtype
        # byte is carried over unchanged.
        if self.msglen is None:
            body_len = len(p) - 4
            header = struct.pack("!I", (orb(p[0]) << 24) | body_len)
            p = header + p[4:]
        return p + pay

    def guess_payload_class(self, p):
        # Whatever follows a Handshake message is handed to the padding layer.
        return conf.padding_layer

    def tls_session_update(self, msg_str):
        """
        Record this message in the session, both as raw bytes and as the
        parsed packet. Used in both post_build and post_dissection contexts.
        """
        session = self.tls_session
        session.handshake_messages.append(msg_str)
        session.handshake_messages_parsed.append(self)
86
87
88###############################################################################
89### HelloRequest                                                            ###
90###############################################################################
91
class TLSHelloRequest(_TLSHandshake):
    """
    HelloRequest: only the 4-byte handshake header, with no body fields.
    """
    name = "TLS Handshake - Hello Request"
    fields_desc = [ByteEnumField("msgtype", 0, _tls_handshake_type),
                   ThreeBytesField("msglen", None)]

    def tls_session_update(self, msg_str):
        """
        Deliberately a no-op: this message must not be appended to the list
        of handshake messages hashed in the Finished and CertificateVerify
        messages.
        """
        return None
103
104
105###############################################################################
106### ClientHello fields                                                      ###
107###############################################################################
108
class _GMTUnixTimeField(UTCTimeField):
    """
    "The current time and date in standard UNIX 32-bit format (seconds since
     the midnight starting Jan 1, 1970, GMT, ignoring leap seconds)."
    """
    def i2h(self, pkt, x):
        # An unset timestamp is presented as 0 instead of None.
        return 0 if x is None else x
118
class _TLSRandomBytesField(StrFixedLenField):
    """Fixed-length random bytes, displayed through repr_hex()."""
    def i2repr(self, pkt, x):
        if x is not None:
            return repr_hex(self.i2h(pkt, x))
        return repr(x)
124
125
class _SessionIDField(StrLenField):
    """
    opaque SessionID<0..32>; section 7.4.1.2 of RFC 4346.

    Behaves exactly like StrLenField; the subclass only gives it a name.
    """
131
132
class _CipherSuitesField(StrLenField):
    """
    List of fixed-size integer codes (cipher suites by default).

    Each item is packed with `itemfmt`. `dico` maps code -> name: it is used
    to accept names as input (via s2i) and to display known codes (via i2s).
    """
    __slots__ = ["itemfmt", "itemsize", "i2s", "s2i"]
    islist = 1

    def __init__(self, name, default, dico, length_from=None, itemfmt="!H"):
        StrLenField.__init__(self, name, default, length_from=length_from)
        self.itemfmt = itemfmt
        self.itemsize = struct.calcsize(itemfmt)
        i2s = self.i2s = {}
        s2i = self.s2i = {}
        for code, label in dico.items():
            i2s[code] = label
            s2i[label] = code

    def any2i_one(self, pkt, x):
        """
        Convert one user-supplied item to its integer code.

        Accepts suite classes/instances (their `val` is used), names given
        as str (or as bytes, which are decoded first) and plain integers.
        Raises KeyError for an unknown name.
        """
        if isinstance(x, (_GenericCipherSuite, _GenericCipherSuiteMetaclass)):
            x = x.val
        if isinstance(x, bytes) and not isinstance(x, str):
            # Python 3 only: the names stored in s2i are str, not bytes.
            x = x.decode("ascii")
        if isinstance(x, str):
            x = self.s2i[x]
        return x

    def i2repr_one(self, pkt, x):
        # Unknown codes are shown in hex, zero-padded to the full item
        # width: two hex digits per byte.
        fmt = "0x%%0%dx" % (2 * self.itemsize)
        return self.i2s.get(x, fmt % x)

    def any2i(self, pkt, x):
        if x is None:
            return None
        if not isinstance(x, list):
            x = [x]
        return [self.any2i_one(pkt, z) for z in x]

    def i2repr(self, pkt, x):
        if x is None:
            return "None"
        reprs = [self.i2repr_one(pkt, z) for z in x]
        if len(reprs) == 1:
            return reprs[0]
        return "[%s]" % ", ".join(reprs)

    def i2m(self, pkt, val):
        if val is None:
            val = []
        return b"".join(struct.pack(self.itemfmt, x) for x in val)

    def m2i(self, pkt, m):
        # A trailing partial item makes struct.unpack raise, as before.
        size = self.itemsize
        return [struct.unpack(self.itemfmt, m[i:i + size])[0]
                for i in range(0, len(m), size)]

    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(i) * self.itemsize
192
193
class _CompressionMethodsField(_CipherSuitesField):
    """
    Compression-method list: same machinery as _CipherSuitesField, except
    that items may also be given as compression classes/instances.
    """

    def any2i_one(self, pkt, x):
        """
        Convert one item to its integer code. Accepts compression
        classes/instances (their `val` is used), names as str (or bytes,
        decoded first) and plain integers. Raises KeyError on unknown names.
        """
        if isinstance(x, (_GenericComp, _GenericCompMetaclass)):
            x = x.val
        if isinstance(x, bytes) and not isinstance(x, str):
            # Python 3 only: the names stored in s2i are str, not bytes.
            x = x.decode("ascii")
        if isinstance(x, str):
            x = self.s2i[x]
        return x
203
204
205###############################################################################
206### ClientHello                                                             ###
207###############################################################################
208
class TLSClientHello(_TLSHandshake):
    """
    TLS ClientHello, with abilities to handle extensions.

    The Random structure follows the RFC 5246: while it is 32-byte long,
    many implementations use the first 4 bytes as a gmt_unix_time, and then
    the remaining 28 bytes should be completely random. This was designed in
    order to (sort of) mitigate broken RNGs. If you prefer to show the full
    32 random bytes without any GMT time, just comment in/out the lines below.

    When built with `ciphers` left to None, a default list of cipher suites
    is inserted; if `ext` is None as well, default extensions are appended
    (renegotiation_info, server_name "secdev.org", signature_algorithms and
    supported_groups).
    """
    name = "TLS Handshake - Client Hello"
    fields_desc = [ByteEnumField("msgtype", 1, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSClientVersionField("version", None, _tls_version),

                   #_TLSRandomBytesField("random_bytes", None, 32),
                   _GMTUnixTimeField("gmt_unix_time", None),
                   _TLSRandomBytesField("random_bytes", None, 28),

                   FieldLenField("sidlen", None, fmt="B", length_of="sid"),
                   _SessionIDField("sid", "",
                                   length_from=lambda pkt: pkt.sidlen),

                   FieldLenField("cipherslen", None, fmt="!H",
                                 length_of="ciphers"),
                   _CipherSuitesField("ciphers", None,
                                      _tls_cipher_suites, itemfmt="!H",
                                      length_from=lambda pkt: pkt.cipherslen),

                   FieldLenField("complen", None, fmt="B", length_of="comp"),
                   _CompressionMethodsField("comp", [0],
                                            _tls_compression_algs,
                                            itemfmt="B",
                                            length_from=
                                                lambda pkt: pkt.complen),

                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # Fixed bytes before the extensions: version (2) +
                   # random (32) + sidlen (1) + cipherslen (2) +
                   # complen (1) + extlen (2) = 40.
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (pkt.msglen -
                                                             (pkt.sidlen or 0) -
                                                             (pkt.cipherslen or 0) -
                                                             (pkt.complen or 0) -
                                                             40))]

    def post_build(self, p, pay):
        # Offsets 10..38 hold the 28 random bytes: they follow the 4-byte
        # handshake header, the 2-byte version and the 4-byte gmt_unix_time.
        if self.random_bytes is None:
            p = p[:10] + randstring(28) + p[10 + 28:]

        # if no ciphersuites were provided, we add a few usual, supported
        # ciphersuites along with the appropriate extensions
        if self.ciphers is None:
            # The sidlen byte sits at offset 38. Read it from the built
            # bytes: self.sidlen is still None when it is computed from sid.
            cipherstart = 39 + orb(p[38])
            default_ciphers = (0xc02b, 0xc023, 0xc02f, 0xc027, 0x009e,
                               0x0067, 0x009c, 0x003c, 0xc009, 0xc013,
                               0x0033, 0x002f, 0x000a)
            cipher_bytes = b"".join(struct.pack("!H", c)
                                    for c in default_ciphers)
            cipher_vector = struct.pack("!H", len(cipher_bytes)) + cipher_bytes
            # Replace the empty 2-byte cipherslen with length + suites.
            p = p[:cipherstart] + cipher_vector + p[cipherstart + 2:]
            if self.ext is None:
                ext_len = b'\x00\x2c'
                ext_reneg = b'\xff\x01\x00\x01\x00'
                ext_sn = b'\x00\x00\x00\x0f\x00\r\x00\x00\nsecdev.org'
                ext_sigalg = b'\x00\r\x00\x08\x00\x06\x04\x03\x04\x01\x02\x01'
                ext_supgroups = b'\x00\n\x00\x04\x00\x02\x00\x17'
                p += ext_len + ext_reneg + ext_sn + ext_sigalg + ext_supgroups

        return super(TLSClientHello, self).post_build(p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the client_random
        along with the raw string representing this handshake message.
        """
        super(TLSClientHello, self).tls_session_update(msg_str)

        s = self.tls_session
        s.advertised_tls_version = self.version
        self.random_bytes = msg_str[10:38]
        # Full 32-byte client random (gmt_unix_time + random_bytes) exactly
        # as on the wire; slicing avoids packing gmt_unix_time, which is
        # None when it was never set.
        s.client_random = msg_str[6:38]
        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_SupportedVersions):
                    if s.tls13_early_secret is None:
                        # this is not recomputed if there was a TLS 1.3 HRR
                        s.compute_tls13_early_secrets()
                    break
292
293###############################################################################
294### ServerHello                                                             ###
295###############################################################################
296
class TLSServerHello(TLSClientHello):
    """
    TLS ServerHello, with abilities to handle extensions.

    The Random structure follows the RFC 5246: while it is 32-byte long,
    many implementations use the first 4 bytes as a gmt_unix_time, and then
    the remaining 28 bytes should be completely random. This was designed in
    order to (sort of) mitigate broken RNGs. If you prefer to show the full
    32 random bytes without any GMT time, just comment in/out the lines below.
    """
    name = "TLS Handshake - Server Hello"
    fields_desc = [ByteEnumField("msgtype", 2, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSVersionField("version", None, _tls_version),

                   #_TLSRandomBytesField("random_bytes", None, 32),
                   _GMTUnixTimeField("gmt_unix_time", None),
                   _TLSRandomBytesField("random_bytes", None, 28),

                   FieldLenField("sidlen", None, length_of="sid", fmt="B"),
                   _SessionIDField("sid", "",
                                   length_from=lambda pkt: pkt.sidlen),

                   EnumField("cipher", None, _tls_cipher_suites),
                   _CompressionMethodsField("comp", [0],
                                            _tls_compression_algs,
                                            itemfmt="B",
                                            length_from=lambda pkt: 1),

                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # Fixed bytes before the extensions: version (2) +
                   # random (32) + sidlen (1) + cipher (2) + comp (1) +
                   # extlen (2) = 40.
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (pkt.msglen -
                                                             (pkt.sidlen or 0) -
                                                             40))]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        # The version sits right after the 4-byte header; TLS 1.3 (0x0304)
        # and draft versions (above 0x7f00) use the TLS 1.3 layout.
        if _pkt and len(_pkt) >= 6:
            version = struct.unpack("!H", _pkt[4:6])[0]
            if version == 0x0304 or version > 0x7f00:
                return TLS13ServerHello
        return TLSServerHello

    def post_build(self, p, pay):
        # Same random layout as ClientHello; skip TLSClientHello.post_build,
        # which would also inject default ciphers/extensions.
        if self.random_bytes is None:
            p = p[:10] + randstring(28) + p[10 + 28:]
        return super(TLSClientHello, self).post_build(p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the server_random
        along with the raw string representing this handshake message.
        We also store the session_id, the cipher suite (if recognized),
        the compression method, and finally we instantiate the pending write
        and read connection states. Usually they get updated later on in the
        negotiation when we learn the session keys, and eventually they
        are committed once a ChangeCipherSpec has been sent/received.
        """
        super(TLSClientHello, self).tls_session_update(msg_str)

        s = self.tls_session
        s.tls_version = self.version
        self.random_bytes = msg_str[10:38]
        # Full 32-byte server random exactly as on the wire; slicing avoids
        # packing gmt_unix_time, which is None when it was never set.
        s.server_random = msg_str[6:38]
        s.sid = self.sid

        cs_cls = None
        if self.cipher:
            cs_val = self.cipher
            if cs_val not in _tls_cipher_suites_cls:
                warning("Unknown cipher suite %d from ServerHello" % cs_val)
                # we do not try to set a default nor stop the execution
            else:
                cs_cls = _tls_cipher_suites_cls[cs_val]

        comp_cls = Comp_NULL
        if self.comp:
            comp_val = self.comp[0]
            if comp_val not in _tls_compression_algs_cls:
                err = "Unknown compression alg %d from ServerHello" % comp_val
                warning(err)
                comp_val = 0
            comp_cls = _tls_compression_algs_cls[comp_val]

        connection_end = s.connection_end
        s.pwcs = writeConnState(ciphersuite=cs_cls,
                                compression_alg=comp_cls,
                                connection_end=connection_end,
                                tls_version=self.version)
        s.prcs = readConnState(ciphersuite=cs_cls,
                               compression_alg=comp_cls,
                               connection_end=connection_end,
                               tls_version=self.version)
392
393
class TLS13ServerHello(TLSClientHello):
    """ TLS 1.3 ServerHello """
    name = "TLS 1.3 Handshake - Server Hello"
    fields_desc = [ByteEnumField("msgtype", 2, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSVersionField("version", None, _tls_version),
                   _TLSRandomBytesField("random_bytes", None, 32),
                   EnumField("cipher", None, _tls_cipher_suites),
                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # Fixed bytes before the extensions: version (2) +
                   # random (32) + cipher (2) + extlen (2) = 38.
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (pkt.msglen -
                                                             38))]

    def post_build(self, p, pay):
        # TLSClientHello.post_build would assume the TLS 1.2 layout (28
        # random bytes at offset 10) and read the nonexistent 'ciphers'
        # field, so bypass it. Here the 32 random bytes follow the 4-byte
        # header and the 2-byte version.
        if self.random_bytes is None:
            p = p[:6] + randstring(32) + p[6 + 32:]
        return _TLSHandshake.post_build(self, p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the server_random along with
        the raw string representing this handshake message. We also store the
        cipher suite (if recognized), and finally we instantiate the write and
        read connection states.
        """
        super(TLSClientHello, self).tls_session_update(msg_str)

        s = self.tls_session
        s.tls_version = self.version
        # Use the wire bytes: random_bytes is still None when the random
        # was generated by post_build.
        s.server_random = msg_str[6:38]

        cs_cls = None
        if self.cipher:
            cs_val = self.cipher
            if cs_val not in _tls_cipher_suites_cls:
                warning("Unknown cipher suite %d from ServerHello" % cs_val)
                # we do not try to set a default nor stop the execution
            else:
                cs_cls = _tls_cipher_suites_cls[cs_val]

        connection_end = s.connection_end
        s.pwcs = writeConnState(ciphersuite=cs_cls,
                                connection_end=connection_end,
                                tls_version=self.version)
        s.triggered_pwcs_commit = True
        s.prcs = readConnState(ciphersuite=cs_cls,
                               connection_end=connection_end,
                               tls_version=self.version)
        s.triggered_prcs_commit = True

        if s.tls13_early_secret is None:
            # In case the connState was not pre-initialized, we could not
            # compute the early secrets at the ClientHello, so we do it here.
            s.compute_tls13_early_secrets()
        s.compute_tls13_handshake_secrets()
444
445
446###############################################################################
447### HelloRetryRequest                                                       ###
448###############################################################################
449
class TLSHelloRetryRequest(_TLSHandshake):
    """
    TLS 1.3 HelloRetryRequest: a protocol version followed by extensions.
    """
    name = "TLS 1.3 Handshake - Hello Retry Request"
    fields_desc = [ByteEnumField("msgtype", 6, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSVersionField("version", None, _tls_version),
                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # Everything after version (2) and extlen (2).
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: pkt.msglen - 4)]
458
459
460###############################################################################
461### EncryptedExtensions                                                     ###
462###############################################################################
463
class TLSEncryptedExtensions(_TLSHandshake):
    """
    TLS 1.3 EncryptedExtensions: nothing but an extensions block.
    """
    name = "TLS 1.3 Handshake - Encrypted Extensions"
    fields_desc = [ByteEnumField("msgtype", 8, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # Everything after the 2-byte extlen.
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: pkt.msglen - 2)]
471
472
473###############################################################################
474### Certificate                                                             ###
475###############################################################################
476
477#XXX It might be appropriate to rewrite this mess with basic 3-byte FieldLenField.
478
class _ASN1CertLenField(FieldLenField):
    """
    This is mostly a 3-byte FieldLenField: the value is handled as a "!I"
    integer whose most significant byte is dropped on the wire.
    """
    def __init__(self, name, default, length_of=None, adjust=lambda pkt, x: x):
        self.length_of = length_of
        self.adjust = adjust
        Field.__init__(self, name, default, fmt="!I")

    def i2m(self, pkt, x):
        # Explicit values (or fields without a length_of target) pass through.
        if x is not None or self.length_of is None:
            return x
        fld, fval = pkt.getfield_and_val(self.length_of)
        return self.adjust(pkt, fld.i2len(pkt, fval))

    def addfield(self, pkt, s, val):
        packed = struct.pack(self.fmt, self.i2m(pkt, val))
        return s + packed[1:4]

    def getfield(self, pkt, s):
        head, rest = s[:3], s[3:]
        value = struct.unpack(self.fmt, b"\x00" + head)[0]
        return rest, self.m2i(pkt, value)
501
502
class _ASN1CertListField(StrLenField):
    """
    List of certificates, each prefixed with its 3-byte length.

    Dissected items are (length, Cert) tuples. When building, each item may
    be raw bytes (used verbatim, so already length-prefixed), a Cert, or a
    (length, Cert-or-bytes) tuple; the whole value may also be raw bytes or
    a single Cert.
    """
    islist = 1

    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(self.i2m(pkt, i))

    def getfield(self, pkt, s):
        """
        Extract Certs in a loop.
        XXX We should provide safeguards when trying to parse a Cert.
        """
        tmp_len = None
        if self.length_from is not None:
            tmp_len = self.length_from(pkt)

        lst = []
        ret = b""
        m = s
        if tmp_len is not None:
            m, ret = s[:tmp_len], s[tmp_len:]
        while m:
            clen = struct.unpack("!I", b'\x00' + m[:3])[0]
            lst.append((clen, Cert(m[3:3 + clen])))
            m = m[3 + clen:]
        return m + ret, lst

    def i2m(self, pkt, i):
        def i2m_one(i):
            # Raw bytes are assumed to already carry their length prefix.
            # (bytes, not str: on Python 3 str cannot be joined with bytes.)
            if isinstance(i, bytes):
                return i
            if isinstance(i, Cert):
                s = i.der
                return struct.pack("!I", len(s))[1:4] + s

            (l, s) = i
            if isinstance(s, Cert):
                s = s.der
            return struct.pack("!I", l)[1:4] + s

        if i is None:
            return b""
        if isinstance(i, bytes):
            return i
        if isinstance(i, Cert):
            i = [i]
        return b"".join(i2m_one(x) for x in i)

    def any2i(self, pkt, x):
        return x
554
class _ASN1CertField(StrLenField):
    """
    A single certificate prefixed with its 3-byte length.

    Dissects to a (length, Cert) tuple. When building, the value may be raw
    bytes (used verbatim, so already length-prefixed), a Cert, or a
    (length, Cert-or-bytes) tuple.
    """
    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(self.i2m(pkt, i))

    def getfield(self, pkt, s):
        tmp_len = None
        if self.length_from is not None:
            tmp_len = self.length_from(pkt)
        ret = b""
        m = s
        if tmp_len is not None:
            m, ret = s[:tmp_len], s[tmp_len:]
        clen = struct.unpack("!I", b'\x00' + m[:3])[0]
        len_cert = (clen, Cert(m[3:3 + clen]))
        # Bytes after this certificate are returned for the next field.
        m = m[3 + clen:]
        return m + ret, len_cert

    def i2m(self, pkt, i):
        def i2m_one(i):
            # Raw bytes are assumed to already carry their length prefix.
            # (bytes, not str: on Python 3 str cannot be mixed with bytes.)
            if isinstance(i, bytes):
                return i
            if isinstance(i, Cert):
                s = i.der
                return struct.pack("!I", len(s))[1:4] + s

            (l, s) = i
            if isinstance(s, Cert):
                s = s.der
            return struct.pack("!I", l)[1:4] + s

        if i is None:
            return b""
        return i2m_one(i)

    def any2i(self, pkt, x):
        return x
594
595
class TLSCertificate(_TLSHandshake):
    """
    XXX We do not support RFC 5081, i.e. OpenPGP certificates.
    """
    name = "TLS Handshake - Certificate"
    fields_desc = [ByteEnumField("msgtype", 11, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _ASN1CertLenField("certslen", None, length_of="certs"),
                   _ASN1CertListField("certs", [],
                                      length_from=lambda pkt: pkt.certslen)]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        # With data to dissect and a session at TLS 1.3 (0x0304) or above,
        # the TLS 1.3 Certificate layout is used instead.
        if not _pkt:
            return TLSCertificate
        tls_session = kargs.get("tls_session", None)
        if tls_session and (tls_session.tls_version or 0) >= 0x0304:
            return TLS13Certificate
        return TLSCertificate

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        # Each dissected item is a (length, Cert) tuple; keep the Cert.
        certs = [entry[1] for entry in self.certs]
        if self.tls_session.connection_end == "client":
            self.tls_session.server_certs = certs
        else:
            self.tls_session.client_certs = certs
622
623
class _ASN1CertAndExt(_GenericTLSSessionInheritance):
    """
    One entry of a TLS 1.3 Certificate message: a length-prefixed
    certificate followed by its own extensions block.
    """
    name = "Certificate and Extensions"
    fields_desc = [_ASN1CertField("cert", ""),
                   FieldLenField("extlen", None, length_of="ext"),
                   _ExtensionsField("ext", [],
                                    length_from=lambda pkt: pkt.extlen)]

    def extract_padding(self, s):
        # Keep no payload: every remaining byte goes back to the caller.
        return b"", s
632
class _ASN1CertAndExtListField(PacketListField):
    """PacketListField whose items are dissected in the parent's TLS session."""
    def m2i(self, pkt, m):
        session = pkt.tls_session
        return self.cls(m, tls_session=session)
636
class TLS13Certificate(_TLSHandshake):
    """
    TLS 1.3 Certificate: a request context followed by a list of
    certificate-and-extensions entries.
    """
    name = "TLS 1.3 Handshake - Certificate"
    fields_desc = [ByteEnumField("msgtype", 11, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   FieldLenField("cert_req_ctxt_len", None, fmt="B",
                                 length_of="cert_req_ctxt"),
                   StrLenField("cert_req_ctxt", "",
                               length_from=lambda pkt: pkt.cert_req_ctxt_len),
                   _ASN1CertLenField("certslen", None, length_of="certs"),
                   _ASN1CertAndExtListField("certs", [], _ASN1CertAndExt,
                                            length_from=lambda pkt: pkt.certslen)]

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        # An empty list leaves the session's certificates untouched.
        if not self.certs:
            return
        # entry.cert is a (length, Cert) tuple; keep the Cert.
        certs = [entry.cert[1] for entry in self.certs]
        if self.tls_session.connection_end == "client":
            self.tls_session.server_certs = certs
        else:
            self.tls_session.client_certs = certs
660
661
662###############################################################################
663### ServerKeyExchange                                                       ###
664###############################################################################
665
class TLSServerKeyExchange(_TLSHandshake):
    """
    ServerKeyExchange: key exchange parameters, then (unless the key
    exchange is anonymous) a signature over
    client_random + server_random + raw(params).
    """
    name = "TLS Handshake - Server Key Exchange"
    fields_desc = [ByteEnumField("msgtype", 12, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSServerParamsField("params", None,
                                         length_from=lambda pkt: pkt.msglen),
                   _TLSSignatureField("sig", None,
                                      length_from=lambda pkt: (pkt.msglen -
                                                               len(pkt.params)))]

    def build(self, *args, **kargs):
        """
        We overload build() method in order to provide a valid default value
        for params based on TLS session if not provided. This cannot be done by
        overriding i2m() because the method is called on a copy of the packet.

        The 'params' field is built according to key_exchange.server_kx_msg_cls
        which should have been set after receiving a cipher suite in a
        previous ServerHello. Usual cases are:
        - None: for RSA encryption or fixed FF/ECDH. This should never happen,
          as no ServerKeyExchange should be generated in the first place.
        - ServerDHParams: for ephemeral FFDH. In that case, the parameter to
          server_kx_msg_cls does not matter.
        - ServerECDH*Params: for ephemeral ECDH. There are actually three
          classes, which are dispatched by _tls_server_ecdh_cls_guess on
          the first byte retrieved. The default here is b"\03", which
          corresponds to ServerECDHNamedCurveParams (implicit curves).

        When the Server*DHParams are built via .fill_missing(), the session
        server_kx_privkey will be updated accordingly.
        """
        fval = self.getfieldval("params")
        if fval is None:
            s = self.tls_session
            if s.pwcs:
                if s.pwcs.key_exchange.export:
                    cls = ServerRSAParams(tls_session=s)
                else:
                    cls = s.pwcs.key_exchange.server_kx_msg_cls(b"\x03")
                    cls = cls(tls_session=s)
                try:
                    cls.fill_missing()
                except Exception as e:
                    # Best effort: keep the partially filled params instead
                    # of aborting the build, but leave a trace. (A bare
                    # except would also swallow KeyboardInterrupt.)
                    log_runtime.debug("TLS: could not fill ServerKeyExchange "
                                      "params: %s", e)
            else:
                cls = Raw()
            self.params = cls

        fval = self.getfieldval("sig")
        if fval is None:
            s = self.tls_session
            if s.pwcs:
                if not s.pwcs.key_exchange.anonymous:
                    p = self.params
                    if p is None:
                        p = b""
                    m = s.client_random + s.server_random + raw(p)
                    cls = _TLSSignature(tls_session=s)
                    cls._update_sig(m, s.server_key)
                else:
                    cls = Raw()
            else:
                cls = Raw()
            self.sig = cls

        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        """
        While previously dissecting Server*DHParams, the session
        server_kx_pubkey should have been updated.

        XXX Add a 'fixed_dh' OR condition to the 'anonymous' test.
        """
        s = self.tls_session
        if s.prcs and s.prcs.key_exchange.no_ske:
            pkt_info = pkt.firstlayer().summary()
            log_runtime.info("TLS: useless ServerKeyExchange [%s]", pkt_info)
        # The signature is checked against the first server certificate.
        if (s.prcs and
                not s.prcs.key_exchange.anonymous and
                s.client_random and s.server_random and
                s.server_certs):
            m = s.client_random + s.server_random + raw(self.params)
            sig_test = self.sig._verify_sig(m, s.server_certs[0])
            if not sig_test:
                pkt_info = pkt.firstlayer().summary()
                log_runtime.info("TLS: invalid ServerKeyExchange signature [%s]", pkt_info)
752
753
754###############################################################################
755### CertificateRequest                                                      ###
756###############################################################################
757
758_tls_client_certificate_types =  {  1: "rsa_sign",
759                                    2: "dss_sign",
760                                    3: "rsa_fixed_dh",
761                                    4: "dss_fixed_dh",
762                                    5: "rsa_ephemeral_dh_RESERVED",
763                                    6: "dss_ephemeral_dh_RESERVED",
764                                   20: "fortezza_dms_RESERVED",
765                                   64: "ecdsa_sign",
766                                   65: "rsa_fixed_ecdh",
767                                   66: "ecdsa_fixed_ecdh" }
768
769
class _CertTypesField(_CipherSuitesField):
    """
    List of ClientCertificateType values carried by a CertificateRequest.

    Behaves exactly like _CipherSuitesField. The item format and enum
    mapping are supplied by the caller (TLSCertificateRequest passes
    itemfmt="!B" and _tls_client_certificate_types).
    """
772
class _CertAuthoritiesField(StrLenField):
    """
    List of certificate authority distinguished names.

    Each entry is kept as a (length, dn_bytes) tuple, where dn_bytes is the
    raw content that follows a 2-byte big-endian length. A truncated final
    entry keeps its announced length together with whatever bytes remain.
    A single trailing byte (too short for a length prefix) is ignored.

    XXX Rework this with proper ASN.1 parsing.
    """
    islist = 1

    def getfield(self, pkt, s):
        size = self.length_from(pkt)
        field_bytes, rest = s[:size], s[size:]
        return rest, self.m2i(pkt, field_bytes)

    def m2i(self, pkt, m):
        entries = []
        remaining = m
        while len(remaining) >= 2:
            dn_len = struct.unpack("!H", remaining[:2])[0]
            body = remaining[2:]
            if len(body) < dn_len:
                # Truncated entry: keep what we have and stop.
                entries.append((dn_len, body))
                break
            entries.append((dn_len, body[:dn_len]))
            remaining = body[dn_len:]
        return entries

    def i2m(self, pkt, i):
        # Re-prefix every DN with its 2-byte length.
        return b"".join(struct.pack("!H", entry[0]) + entry[1] for entry in i)

    def addfield(self, pkt, s, val):
        encoded = self.i2m(pkt, val)
        return s + encoded

    def i2len(self, pkt, val):
        return 0 if val is None else len(self.i2m(pkt, val))
806
807
class TLSCertificateRequest(_TLSHandshake):
    """
    CertificateRequest handshake message (type 13).

    Body: certificate types (1-byte length prefix), signature and hash
    algorithms, then certificate authorities (2-byte length prefix) kept
    as raw (length, DN bytes) pairs.

    NOTE(review): sig_algs is always present, which matches the TLS 1.2
    layout. Confirm how messages from older protocol versions (no such
    field) are expected to dissect.
    """
    name = "TLS Handshake - Certificate Request"
    fields_desc = [
        ByteEnumField("msgtype", 13, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        FieldLenField("ctypeslen", None, fmt="B", length_of="ctypes"),
        _CertTypesField("ctypes", [1, 64], _tls_client_certificate_types,
                        itemfmt="!B",
                        length_from=lambda pkt: pkt.ctypeslen),
        SigAndHashAlgsLenField("sig_algs_len", None, length_of="sig_algs"),
        SigAndHashAlgsField("sig_algs", [0x0403, 0x0401, 0x0201],
                            EnumField("hash_sig", None, _tls_hash_sig),
                            length_from=lambda pkt: pkt.sig_algs_len),
        FieldLenField("certauthlen", None, fmt="!H", length_of="certauth"),
        _CertAuthoritiesField("certauth", [],
                              length_from=lambda pkt: pkt.certauthlen),
    ]
827
828
829###############################################################################
830### ServerHelloDone                                                         ###
831###############################################################################
832
class TLSServerHelloDone(_TLSHandshake):
    """ServerHelloDone handshake message (type 14): header only, no body."""
    name = "TLS Handshake - Server Hello Done"
    fields_desc = [
        ByteEnumField("msgtype", 14, _tls_handshake_type),
        ThreeBytesField("msglen", None),
    ]
837
838
839###############################################################################
840### CertificateVerify                                                       ###
841###############################################################################
842
class TLSCertificateVerify(_TLSHandshake):
    """
    CertificateVerify handshake message (type 15).

    The signed content is the concatenation of the session's handshake
    messages. From TLS 1.3 (tls_version >= 0x0304) on, it is instead
    64 spaces (0x20), a context string naming the signing end, a zero
    byte, and the digest of those handshake messages.
    """
    name = "TLS Handshake - Certificate Verify"
    fields_desc = [ ByteEnumField("msgtype", 15, _tls_handshake_type),
                    ThreeBytesField("msglen", None),
                    _TLSSignatureField("sig", None,
                                 length_from=lambda pkt: pkt.msglen) ]

    def build(self, *args, **kargs):
        """
        Compute 'sig' when it was left unset, then build the message.

        Signs with client_key when we are the client and with server_key
        when we are the server. The TLS 1.3 digest uses the write
        connection state hash.
        """
        sig = self.getfieldval("sig")
        if sig is None:
            s = self.tls_session
            m = b"".join(s.handshake_messages)
            if s.tls_version >= 0x0304:
                # We are the signer, so the context names our own end.
                # These literals must be bytes: concatenating a str with
                # the bytes prefix below raises TypeError on Python 3.
                if s.connection_end == "client":
                    context_string = b"TLS 1.3, client CertificateVerify"
                elif s.connection_end == "server":
                    context_string = b"TLS 1.3, server CertificateVerify"
                m = b"\x20"*64 + context_string + b"\x00" + s.wcs.hash.digest(m)
            self.sig = _TLSSignature(tls_session=s)
            if s.connection_end == "client":
                self.sig._update_sig(m, s.client_key)
            elif s.connection_end == "server":
                # should be TLS 1.3 only
                self.sig._update_sig(m, s.server_key)
        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        """
        Verify the received signature and log (never raise) on failure.

        The peer signed this message, so the TLS 1.3 context string names
        the other end and the digest uses the read connection state hash.
        The signature is checked against the peer's first certificate.
        """
        s = self.tls_session
        m = b"".join(s.handshake_messages)
        if s.tls_version >= 0x0304:
            if s.connection_end == "client":
                context_string = b"TLS 1.3, server CertificateVerify"
            elif s.connection_end == "server":
                context_string = b"TLS 1.3, client CertificateVerify"
            m = b"\x20"*64 + context_string + b"\x00" + s.rcs.hash.digest(m)

        if s.connection_end == "server":
            peer_certs = s.client_certs
        elif s.connection_end == "client":
            # should be TLS 1.3 only
            peer_certs = s.server_certs
        else:
            peer_certs = None

        if peer_certs and len(peer_certs) > 0:
            if not self.sig._verify_sig(m, peer_certs[0]):
                pkt_info = pkt.firstlayer().summary()
                log_runtime.info("TLS: invalid CertificateVerify signature [%s]", pkt_info)
892
893
894###############################################################################
895### ClientKeyExchange                                                       ###
896###############################################################################
897
class _TLSCKExchKeysField(PacketField):
    """
    ClientKeyExchange payload whose packet class is chosen at dissection
    time from the pending read connection state's key exchange.
    """
    __slots__ = ["length_from"]
    holds_packet = 1

    def __init__(self, name, length_from=None, remain=0):
        self.length_from = length_from
        PacketField.__init__(self, name, None, None, remain=remain)

    def m2i(self, pkt, m):
        """
        The client_kx_msg may be either None, EncryptedPreMasterSecret
        (for RSA encryption key exchange), ClientDiffieHellmanPublic,
        or ClientECDiffieHellmanPublic. When either one of them gets
        dissected, the session context is updated accordingly.

        Bytes beyond length_from(pkt) are attached as Padding. Without a
        usable key exchange class, the body is kept as Raw.
        """
        size = self.length_from(pkt)
        body, trailing = m[:size], m[size:]

        session = pkt.tls_session
        kx_msg_cls = None
        pending = session.prcs
        if pending and pending.key_exchange:
            kx_msg_cls = pending.key_exchange.client_kx_msg_cls

        if kx_msg_cls is not None:
            return kx_msg_cls(body, tls_session=session) / Padding(trailing)
        return Raw(body) / Padding(trailing)
925
926
class TLSClientKeyExchange(_TLSHandshake):
    """
    ClientKeyExchange handshake message (type 16).

    This class mostly works like TLSServerKeyExchange and its 'params'
    field. When 'exchkeys' is left unset, build() instantiates the client
    key exchange message class of the pending read connection state's key
    exchange, falling back to Raw() when none is available.
    """
    name = "TLS Handshake - Client Key Exchange"
    fields_desc = [ ByteEnumField("msgtype", 16, _tls_handshake_type),
                    ThreeBytesField("msglen", None),
                    _TLSCKExchKeysField("exchkeys",
                                        length_from = lambda pkt: pkt.msglen) ]

    def build(self, *args, **kargs):
        fval = self.getfieldval("exchkeys")
        if fval is None:
            s = self.tls_session
            cls = None
            # Same lookup as _TLSCKExchKeysField.m2i, which already treats a
            # missing key_exchange or a None client_kx_msg_cls as "no class".
            # Calling None here would raise TypeError, so fall back to Raw().
            if s.prcs and s.prcs.key_exchange:
                cls = s.prcs.key_exchange.client_kx_msg_cls
            if cls is None:
                self.exchkeys = Raw()
            else:
                self.exchkeys = cls(tls_session=s)
        return _TLSHandshake.build(self, *args, **kargs)
948
949
950###############################################################################
951### Finished                                                                ###
952###############################################################################
953
class _VerifyDataField(StrLenField):
    """
    Finished verify_data whose length depends on the session version:
    36 bytes for SSLv3 (0x0300), the read-state hash length for TLS 1.3
    and later, and 12 bytes otherwise.
    """
    def getfield(self, pkt, s):
        version = pkt.tls_session.tls_version
        if version == 0x0300:
            vlen = 36
        elif version >= 0x0304:
            vlen = pkt.tls_session.rcs.hash.hash_len
        else:
            vlen = 12
        return s[vlen:], s[:vlen]
963
class TLSFinished(_TLSHandshake):
    """
    Finished handshake message (type 20).

    Before TLS 1.3, 'vdata' comes from the connection state PRF over the
    handshake messages and the master secret. From TLS 1.3 on, it comes
    from the session's compute_tls13_verify_data(). In TLS 1.3, handling a
    Finished also installs a fresh pending connection state and derives
    the next secrets.
    """
    name = "TLS Handshake - Finished"
    fields_desc = [
        ByteEnumField("msgtype", 20, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _VerifyDataField("vdata", None),
    ]

    def build(self, *args, **kargs):
        if self.getfieldval("vdata") is None:
            s = self.tls_session
            transcript = b"".join(s.handshake_messages)
            side = s.connection_end
            if s.tls_version >= 0x0304:
                self.vdata = s.compute_tls13_verify_data(side, "write")
            else:
                self.vdata = s.wcs.prf.compute_verify_data(side, "write",
                                                           transcript,
                                                           s.master_secret)
        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        # Check the received verify_data; only logs, never raises.
        s = self.tls_session
        if s.frozen:
            return
        transcript = b"".join(s.handshake_messages)
        if s.tls_version < 0x0304:
            if s.master_secret is None:
                return
            secret = s.master_secret
            expected = s.rcs.prf.compute_verify_data(s.connection_end, "read",
                                                     transcript, secret)
        else:
            expected = s.compute_tls13_verify_data(s.connection_end, "read")
        if self.vdata != expected:
            log_runtime.info("TLS: invalid Finished received [%s]",
                             pkt.firstlayer().summary())

    def post_build_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        s = self.tls_session
        if s.tls_version < 0x0304:
            return
        # TLS 1.3: prepare the next write state, then derive secrets.
        s.pwcs = writeConnState(ciphersuite=type(s.wcs.ciphersuite),
                                connection_end=s.connection_end,
                                tls_version=s.tls_version)
        s.triggered_pwcs_commit = True
        if s.connection_end == "server":
            s.compute_tls13_traffic_secrets()
        elif s.connection_end == "client":
            s.compute_tls13_traffic_secrets_end()
            s.compute_tls13_resumption_secret()

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        s = self.tls_session
        if s.tls_version < 0x0304:
            return
        # TLS 1.3: prepare the next read state, then derive secrets.
        s.prcs = readConnState(ciphersuite=type(s.rcs.ciphersuite),
                               connection_end=s.connection_end,
                               tls_version=s.tls_version)
        s.triggered_prcs_commit = True
        if s.connection_end == "client":
            s.compute_tls13_traffic_secrets()
        elif s.connection_end == "server":
            s.compute_tls13_traffic_secrets_end()
            s.compute_tls13_resumption_secret()
1030
1031
1032## Additional handshake messages
1033
1034###############################################################################
1035### HelloVerifyRequest                                                      ###
1036###############################################################################
1037
class TLSHelloVerifyRequest(_TLSHandshake):
    """
    Defined for DTLS, see RFC 6347.

    HelloVerifyRequest is handshake type 3, which is also the key used for
    this class in _tls_handshake_cls. The default msgtype must not be 21:
    that value is certificate_url (TLSCertificateURL), so built messages
    would be re-dissected as the wrong class.

    NOTE(review): RFC 6347 also places a server_version before the cookie;
    it is not modeled here.
    """
    name = "TLS Handshake - Hello Verify Request"
    fields_desc = [ ByteEnumField("msgtype", 3, _tls_handshake_type),
                    ThreeBytesField("msglen", None),
                    FieldLenField("cookielen", None,
                                  fmt="B", length_of="cookie"),
                    StrLenField("cookie", "",
                                length_from=lambda pkt: pkt.cookielen) ]
1049
1050
1051###############################################################################
1052### CertificateURL                                                          ###
1053###############################################################################
1054
1055_tls_cert_chain_types = { 0: "individual_certs",
1056                          1: "pkipath" }
1057
class URLAndOptionalHash(Packet):
    """
    One entry of TLSCertificateURL: a 2-byte-length-prefixed URL, then
    'hash_present' (the number of 20-byte blocks in 'hash', rounded up),
    then the hash bytes themselves.
    """
    name = "URLAndOptionHash structure for TLSCertificateURL"
    fields_desc = [
        FieldLenField("urllen", None, length_of="url"),
        StrLenField("url", "", length_from=lambda pkt: pkt.urllen),
        FieldLenField("hash_present", None, fmt="B", length_of="hash",
                      adjust=lambda pkt, x: int(math.ceil(x / 20.))),
        StrLenField("hash", "",
                    length_from=lambda pkt: 20 * pkt.hash_present),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next entry of the enclosing list.
        return Padding
1070
class TLSCertificateURL(_TLSHandshake):
    """
    Defined in RFC 4366. PkiPath structure of section 8 is not implemented yet.

    CertificateURL handshake message (type 21): a chain type followed by a
    length-prefixed list of URLAndOptionalHash entries.
    """
    name = "TLS Handshake - Certificate URL"
    fields_desc = [
        ByteEnumField("msgtype", 21, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ByteEnumField("certchaintype", None, _tls_cert_chain_types),
        FieldLenField("uahlen", None, length_of="uah"),
        PacketListField("uah", [], URLAndOptionalHash,
                        length_from=lambda pkt: pkt.uahlen),
    ]
1082
1083
1084###############################################################################
1085### CertificateStatus                                                       ###
1086###############################################################################
1087
class ThreeBytesLenField(FieldLenField):
    """
    Length field encoded on 3 big-endian bytes.

    The value is handled as a 4-byte '!I' integer: the leading byte is
    dropped when building and a zero byte is prepended when dissecting.
    """
    def __init__(self, name, default,  length_of=None, adjust=lambda pkt, x:x):
        FieldLenField.__init__(self, name, default, length_of=length_of,
                               fmt='!I', adjust=adjust)

    def i2repr(self, pkt, x):
        # i2repr must return a string in both branches; an unset (None)
        # length is displayed as 0, as before.
        if x is None:
            return repr(0)
        return repr(self.i2h(pkt,x))

    def addfield(self, pkt, s, val):
        # Keep the 3 low-order bytes; values >= 2**24 are silently truncated.
        return s+struct.pack(self.fmt, self.i2m(pkt,val))[1:4]

    def getfield(self, pkt, s):
        return  s[3:], self.m2i(pkt, struct.unpack(self.fmt, b"\x00"+s[:3])[0])
1100
# status_type -> class used by _StatusField to parse the response; types
# not listed here keep the field's default class.
_cert_status_cls = {
    1: OCSP_Response,
}
1102
class _StatusField(PacketField):
    """Packet field whose class is selected by the packet's status_type."""
    def m2i(self, pkt, m):
        # Unregistered status types fall back to the field's own class.
        status_cls = _cert_status_cls.get(pkt.status_type, self.cls)
        return status_cls(m)
1110
class TLSCertificateStatus(_TLSHandshake):
    """
    CertificateStatus handshake message (type 22).

    The response is parsed by _StatusField according to status_type,
    defaulting to Raw for unknown types.
    """
    name = "TLS Handshake - Certificate Status"
    fields_desc = [
        ByteEnumField("msgtype", 22, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ByteEnumField("status_type", 1, _cert_status_type),
        ThreeBytesLenField("responselen", None, length_of="response"),
        _StatusField("response", None, Raw),
    ]
1119
1120
1121###############################################################################
1122### SupplementalData                                                        ###
1123###############################################################################
1124
class SupDataEntry(Packet):
    """Generic supplemental data entry: type, 2-byte length, opaque data."""
    name = "Supplemental Data Entry - Generic"
    fields_desc = [
        ShortField("sdtype", None),
        FieldLenField("len", None, length_of="data"),
        StrLenField("data", "", length_from=lambda pkt: pkt.len),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next entry of the enclosing list.
        return Padding
1133
class UserMappingData(Packet):
    """One user mapping item: version byte, 2-byte length, opaque data."""
    name = "User Mapping Data"
    fields_desc = [
        ByteField("version", None),
        FieldLenField("len", None, length_of="data"),
        StrLenField("data", "", length_from=lambda pkt: pkt.len),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next entry of the enclosing list.
        return Padding
1142
class SupDataEntryUM(Packet):
    """
    User mapping supplemental data entry.

    The outer 'len' covers the inner 'dlen' prefix plus the data (hence
    the +2 adjustment); 'dlen' covers the list of UserMappingData only.
    """
    name = "Supplemental Data Entry - User Mapping"
    fields_desc = [
        ShortField("sdtype", None),
        FieldLenField("len", None, length_of="data",
                      adjust=lambda pkt, x: x + 2),
        FieldLenField("dlen", None, length_of="data"),
        PacketListField("data", [], UserMappingData,
                        length_from=lambda pkt: pkt.dlen),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next entry of the enclosing list.
        return Padding
1153
class TLSSupplementalData(_TLSHandshake):
    """
    SupplementalData handshake message (type 23): a 3-byte-length-prefixed
    list of SupDataEntry items.
    """
    name = "TLS Handshake - Supplemental Data"
    fields_desc = [
        ByteEnumField("msgtype", 23, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ThreeBytesLenField("sdatalen", None, length_of="sdata"),
        PacketListField("sdata", [], SupDataEntry,
                        length_from=lambda pkt: pkt.sdatalen),
    ]
1161
1162
1163###############################################################################
1164### NewSessionTicket                                                        ###
1165###############################################################################
1166
class TLSNewSessionTicket(_TLSHandshake):
    """
    NewSessionTicket handshake message (type 4), pre-TLS 1.3 layout.

    dispatch_hook redirects to TLS13NewSessionTicket when the session
    version is TLS 1.3 or later. A client stores the received ticket in
    the session's client_session_ticket.

    XXX When knowing the right secret, we should be able to read the ticket.
    """
    name = "TLS Handshake - New Session Ticket"
    fields_desc = [
        ByteEnumField("msgtype", 4, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        IntField("lifetime", 0xffffffff),
        FieldLenField("ticketlen", None, length_of="ticket"),
        StrLenField("ticket", "", length_from=lambda pkt: pkt.ticketlen),
    ]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        session = kargs.get("tls_session")
        if not session:
            return TLSNewSessionTicket
        if session.tls_version >= 0x0304:
            return TLS13NewSessionTicket
        return TLSNewSessionTicket

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        session = self.tls_session
        if session.connection_end == "client":
            session.client_session_ticket = self.ticket
1190
1191
class TLS13NewSessionTicket(_TLSHandshake):
    """
    NewSessionTicket handshake message (type 4), TLS 1.3 layout.

    The extensions length is derived from msglen: msglen minus the ticket
    length minus 12 bytes of fixed-size fields (ticket_lifetime,
    ticket_age_add, ticketlen and, presumably, a 2-byte extlen). A client
    stores the received ticket in the session's client_session_ticket.

    Uncomment the TicketField line for parsing a RFC 5077 ticket.
    """
    name = "TLS Handshake - New Session Ticket"
    fields_desc = [
        ByteEnumField("msgtype", 4, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        IntField("ticket_lifetime", 0xffffffff),
        IntField("ticket_age_add", 0),
        FieldLenField("ticketlen", None, length_of="ticket"),
        #TicketField("ticket", "",
        StrLenField("ticket", "", length_from=lambda pkt: pkt.ticketlen),
        _ExtensionsLenField("extlen", None, length_of="ext"),
        _ExtensionsField("ext", None,
                         length_from=lambda pkt: (pkt.msglen -
                                                  (pkt.ticketlen or 0) -
                                                  12)),
    ]

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        session = self.tls_session
        if session.connection_end == "client":
            session.client_session_ticket = self.ticket
1215
1216
1217###############################################################################
1218### All handshake messages defined in this module                           ###
1219###############################################################################
1220
# Handshake message type -> class, for every handshake message defined in
# this module.
_tls_handshake_cls = {
    0: TLSHelloRequest,
    1: TLSClientHello,
    2: TLSServerHello,
    3: TLSHelloVerifyRequest,
    4: TLSNewSessionTicket,
    6: TLSHelloRetryRequest,
    8: TLSEncryptedExtensions,
    11: TLSCertificate,
    12: TLSServerKeyExchange,
    13: TLSCertificateRequest,
    14: TLSServerHelloDone,
    15: TLSCertificateVerify,
    16: TLSClientKeyExchange,
    20: TLSFinished,
    21: TLSCertificateURL,
    22: TLSCertificateStatus,
    23: TLSSupplementalData,
}
1230
1231