1## This file is part of Scapy
2## Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
3##               2015, 2016, 2017 Maxence Tury
4## This program is published under a GPLv2 license
5
6"""
7TLS key exchange logic.
8"""
9
10from __future__ import absolute_import
11import math
12
13from scapy.config import conf, crypto_validator
14from scapy.error import warning
15from scapy.fields import *
16from scapy.compat import orb
17from scapy.packet import Packet, Raw, Padding
18from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
19from scapy.layers.tls.session import _GenericTLSSessionInheritance
20from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
21from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
22from scapy.layers.tls.crypto.groups import _ffdh_groups, _tls_named_curves
23import scapy.modules.six as six
24
25if conf.crypto_valid:
26    from cryptography.hazmat.backends import default_backend
27    from cryptography.hazmat.primitives.asymmetric import dh, ec
28
29
30###############################################################################
31### Common Fields                                                           ###
32###############################################################################
33
# Maps a 16-bit signature algorithm code to a printable label.
# Most labels have the form "<hash>+<signature>"; _TLSSignature relies on that
# shape when it splits a label on '+'. The last two entries ("ed25519" and
# "ed448") carry no '+' and therefore cannot be split that way.
_tls_hash_sig = { 0x0000: "none+anon",    0x0001: "none+rsa",
                  0x0002: "none+dsa",     0x0003: "none+ecdsa",
                  0x0100: "md5+anon",     0x0101: "md5+rsa",
                  0x0102: "md5+dsa",      0x0103: "md5+ecdsa",
                  0x0200: "sha1+anon",    0x0201: "sha1+rsa",
                  0x0202: "sha1+dsa",     0x0203: "sha1+ecdsa",
                  0x0300: "sha224+anon",  0x0301: "sha224+rsa",
                  0x0302: "sha224+dsa",   0x0303: "sha224+ecdsa",
                  0x0400: "sha256+anon",  0x0401: "sha256+rsa",
                  0x0402: "sha256+dsa",   0x0403: "sha256+ecdsa",
                  0x0500: "sha384+anon",  0x0501: "sha384+rsa",
                  0x0502: "sha384+dsa",   0x0503: "sha384+ecdsa",
                  0x0600: "sha512+anon",  0x0601: "sha512+rsa",
                  0x0602: "sha512+dsa",   0x0603: "sha512+ecdsa",
                  0x0804: "sha256+rsapss",
                  0x0805: "sha384+rsapss",
                  0x0806: "sha512+rsapss",
                  0x0807: "ed25519",
                  0x0808: "ed448" }
53
54
def phantom_mode(pkt):
    """
    Tell whether the version-dependent fields of 'pkt' must be skipped.

    When 'pkt' has no TLS session, or when the session has no tls_version
    yet (no complete ClientHello was processed, so we are most probably
    handling a signature_algorithms extension), the answer is False.
    Otherwise the answer is True exactly when tls_version is lower than
    0x0303, i.e. older than TLS 1.2.
    """
    session = pkt.tls_session
    version = session.tls_version if session else None
    if not version:
        return False
    return version < 0x0303
67
def phantom_decorate(f, get_or_add):
    """
    Wrap a field's getfield/addfield so that it becomes version-dependent.

    The wrapper expects (field, pkt, s, ...) as positional arguments. While
    phantom_mode(pkt) is False it simply delegates to 'f'. Otherwise the
    field is treated as absent from the wire:
      - get_or_add True (getfield): return (s, field.phantom_value);
      - get_or_add False (addfield): return s untouched.
    """
    def wrapper(*args):
        field, pkt, s = args[:3]
        if not phantom_mode(pkt):
            return f(*args)
        if get_or_add:
            return s, field.phantom_value
        return s
    return wrapper
82
class SigAndHashAlgField(EnumField):
    """
    Used in _TLSSignature.

    The field is skipped (neither parsed nor built) while phantom_mode()
    holds for the packet; its value is then 'phantom_value'.
    """
    # value reported by getfield when the field is skipped
    phantom_value = None
    # True selects the "get" behaviour of phantom_decorate, False the "add" one
    getfield = phantom_decorate(EnumField.getfield, True)
    addfield = phantom_decorate(EnumField.addfield, False)
88
class SigAndHashAlgsLenField(FieldLenField):
    """
    Used in TLS_Ext_SignatureAlgorithms and TLSCertificateRequest.

    The field is skipped (neither parsed nor built) while phantom_mode()
    holds for the packet; its value is then 'phantom_value'.
    """
    # value reported by getfield when the field is skipped
    phantom_value = 0
    # True selects the "get" behaviour of phantom_decorate, False the "add" one
    getfield = phantom_decorate(FieldLenField.getfield, True)
    addfield = phantom_decorate(FieldLenField.addfield, False)
94
class SigAndHashAlgsField(FieldListField):
    """
    Used in TLS_Ext_SignatureAlgorithms and TLSCertificateRequest.

    The field is skipped (neither parsed nor built) while phantom_mode()
    holds for the packet; its value is then 'phantom_value'.
    """
    # value reported by getfield when the field is skipped
    # NOTE(review): phantom_decorate hands out this very list object, so it is
    # shared by every packet dissected in phantom mode; do not mutate it.
    phantom_value = []
    # True selects the "get" behaviour of phantom_decorate, False the "add" one
    getfield = phantom_decorate(FieldListField.getfield, True)
    addfield = phantom_decorate(FieldListField.addfield, False)
100
101
class SigLenField(FieldLenField):
    """
    Length of the signature value.

    SSLv2 (tls_version set and lower than 0x0300) uses implicit lengths, so
    in that case the field is neither read from nor written to the wire.
    """
    def getfield(self, pkt, s):
        """Return (s, None) for SSLv2, else parse the length normally."""
        version = pkt.tls_session.tls_version
        if not version or version >= 0x0300:
            return super(SigLenField, self).getfield(pkt, s)
        return s, None

    def addfield(self, pkt, s, val):
        """With SSLv2 you will never be able to add a sig_len."""
        version = pkt.tls_session.tls_version
        if not version or version >= 0x0300:
            return super(SigLenField, self).addfield(pkt, s, val)
        return s
116
class SigValField(StrLenField):
    """
    Signature value.

    SSLv2 (tls_version set and lower than 0x0300) carries no explicit
    length, so the size is derived from the first client certificate's key
    size, or guessed as 256 bytes when no client certificate is known.
    """
    def getfield(self, pkt, m):
        session = pkt.tls_session
        version = session.tls_version
        if not version or version >= 0x0300:
            return super(SigValField, self).getfield(pkt, m)

        if len(session.client_certs) > 0:
            # signature is as long as the key of the first client cert
            sig_len = session.client_certs[0].pubKey.pubkey.key_size // 8
        else:
            warning("No client certificate provided. "
                    "We're making a wild guess about the signature size.")
            sig_len = 256
        signature, remain = m[:sig_len], m[sig_len:]
        return remain, self.m2i(pkt, signature)
130
131
class _TLSSignature(_GenericTLSSessionInheritance):
    """
    Digitally-signed structure.

    Before TLS 1.2 the signature implicitly covered the concatenation of a
    MD5 hash and a SHA-1 hash. TLS 1.2 introduced explicit
    SignatureAndHashAlgorithms, i.e. couples of (hash_alg, sig_alg); see
    RFC 5246, section 7.4.1.4.1.

    The TLS 1.2 scheme is the default here. When the packet is created with
    a TLS context whose tls_version is lower than 0x0303, 'sig_alg' is reset
    to None and the implicit scheme is used. With SSLv2 the 'sig_len' field
    is not used either.

    #XXX 'sig_alg' should be set in __init__ depending on the context.
    """
    name = "TLS Digital Signature"
    fields_desc = [
        SigAndHashAlgField("sig_alg", 0x0401, _tls_hash_sig),
        SigLenField("sig_len", None, fmt="!H", length_of="sig_val"),
        SigValField("sig_val", None, length_from=lambda p: p.sig_len),
    ]

    def __init__(self, *args, **kargs):
        super(_TLSSignature, self).__init__(*args, **kargs)
        session = self.tls_session
        if session and session.tls_version and session.tls_version < 0x0303:
            # pre-TLS 1.2 context: no explicit algorithm on the wire
            self.sig_alg = None

    def _update_sig(self, m, key):
        """
        Sign 'm' with the PrivKey 'key' and store the result in 'sig_val'.

        When 'sig_alg' is None the implicit scheme is used (md5-sha1 from
        SSLv3 on, md5 before). Otherwise the hash and the padding ('pss' or
        'pkcs') are read from the 'sig_alg' label; the signature scheme
        itself is the one of 'key', and is not compared with 'sig_alg'.
        Labels without a '+' (e.g. "ed25519") cannot be split and make this
        method raise.
        """
        if self.sig_alg is not None:
            h, sig = _tls_hash_sig[self.sig_alg].split('+')
            t = "pss" if sig.endswith('pss') else "pkcs"
            self.sig_val = key.sign(m, t=t, h=h)
            return

        if self.tls_session.tls_version >= 0x0300:
            legacy_hash = 'md5-sha1'
        else:
            legacy_hash = 'md5'
        self.sig_val = key.sign(m, t='pkcs', h=legacy_hash)

    def _verify_sig(self, m, cert):
        """
        Check that 'sig_val' is the signature of 'm' by the key associated
        to the Cert 'cert'. Return False when 'sig_val' is empty.

        Note that any falsy 'sig_alg' (None, but also 0) selects the
        implicit md5-sha1/md5 scheme here.
        """
        if not self.sig_val:
            return False

        if self.sig_alg:
            h, sig = _tls_hash_sig[self.sig_alg].split('+')
            t = "pss" if sig.endswith('pss') else "pkcs"
            return cert.verify(m, self.sig_val, t=t, h=h)

        if self.tls_session.tls_version >= 0x0300:
            legacy_hash = 'md5-sha1'
        else:
            legacy_hash = 'md5'
        return cert.verify(m, self.sig_val, t='pkcs', h=legacy_hash)

    def guess_payload_class(self, p):
        """Whatever follows the signature is kept as Padding."""
        return Padding
201
class _TLSSignatureField(PacketField):
    """
    Field holding a 'digitally-signed struct', as found in several
    ServerKeyExchange messages and in CertificateVerify. The anonymous case
    (no signature at all) is handled: the value is then None.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, remain=0):
        self.length_from = length_from
        PacketField.__init__(self, name, default, _TLSSignature, remain=remain)

    def m2i(self, pkt, m):
        """Return None when length_from(pkt) is 0, else a _TLSSignature."""
        if self.length_from(pkt) == 0:
            return None
        return _TLSSignature(m, tls_session=pkt.tls_session)

    def getfield(self, pkt, s):
        """
        Return (remain, signature). Bytes the signature stored as Padding
        are detached from it and handed back as 'remain'.
        """
        sig = self.m2i(pkt, s)
        if sig is None:
            return s, None
        if conf.padding_layer not in sig:
            return b"", sig
        pad = sig[conf.padding_layer]
        # detach the padding before reading its content
        del pad.underlayer.payload
        return pad.load, sig
228
229
class _TLSServerParamsField(PacketField):
    """
    This is a dispatcher for the Server*DHParams below, used in
    TLSServerKeyExchange and based on the key_exchange.server_kx_msg_cls.
    When this cls is None, it means that we should not see a ServerKeyExchange,
    so we grab everything within length_from and make it available using Raw.

    When the context has not been set (e.g. when no ServerHello was parsed or
    dissected beforehand), we (kinda) clumsily set the cls by trial and error.
    XXX We could use Serv*DHParams.check_params() once it has been implemented.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, remain=0):
        self.length_from = length_from
        PacketField.__init__(self, name, default, None, remain=remain)

    def m2i(self, pkt, m):
        """
        Dissect 'm' into the relevant server parameters packet.

        A packet is always returned (never a tuple), so that the getfield()
        inherited from PacketField can look for its Padding layer. When no
        parameters class fits, the first length_from(pkt) bytes are returned
        as Raw, followed by the rest as Padding.
        """
        s = pkt.tls_session
        l = self.length_from(pkt)

        def raw_params():
            # generic fallback: opaque params, then the remaining bytes
            return Raw(m[:l])/Padding(m[l:])

        if s.prcs:
            cls = s.prcs.key_exchange.server_kx_msg_cls(m)
            if cls is None:
                return raw_params()
            return cls(m, tls_session=s)

        # No context: try FFDH first. The guess is accepted only if the two
        # bytes following the params look like a known signature algorithm.
        try:
            p = ServerDHParams(m, tls_session=s)
            if pkcs_os2ip(p.load[:2]) not in _tls_hash_sig:
                raise ValueError("does not look like ServerDHParams")
            return p
        except Exception:
            # wrong guess (or failed dissection): fall through to ECDH
            pass

        cls = _tls_server_ecdh_cls_guess(m)
        if cls is None:
            # empty input or unknown curve type
            return raw_params()
        p = cls(m, tls_session=s)
        if pkcs_os2ip(p.load[:2]) not in _tls_hash_sig:
            return raw_params()
        return p
266
267
268###############################################################################
269### Server Key Exchange parameters & value                                  ###
270###############################################################################
271
272### Finite Field Diffie-Hellman
273
class ServerDHParams(_GenericTLSSessionInheritance):
    """
    ServerDHParams for FFDH-based key exchanges, as defined in RFC 5246/7.4.3.

    Either with .fill_missing() or .post_dissection(), the server_kx_privkey or
    server_kx_pubkey of the TLS context are updated according to the
    parsed/assembled values. It is the user's responsibility to store and
    restore the original values if he wants to keep them. For instance, this
    could be done between the writing of a ServerKeyExchange and the receiving
    of a ClientKeyExchange (which includes secret generation).
    """
    name = "Server FFDH parameters"
    fields_desc = [ FieldLenField("dh_plen", None, length_of="dh_p"),
                    StrLenField("dh_p", "",
                                length_from=lambda pkt: pkt.dh_plen),
                    FieldLenField("dh_glen", None, length_of="dh_g"),
                    StrLenField("dh_g", "",
                                length_from=lambda pkt: pkt.dh_glen),
                    FieldLenField("dh_Yslen", None, length_of="dh_Ys"),
                    StrLenField("dh_Ys", "",
                                length_from=lambda pkt: pkt.dh_Yslen) ]

    @crypto_validator
    def fill_missing(self):
        """
        We do not want TLSServerKeyExchange.build() to overload and recompute
        things everytime it is called. This method can be called specifically
        to have things filled in a smart fashion.

        Missing dh_p/dh_g default to the 'modp2048' group, and a missing
        dh_Ys triggers the generation of a fresh server_kx_privkey in the
        TLS context. Every length left to None is set to the length of the
        matching value.

        Note that we do not expect default_params.g to be more than 0xff.
        """
        s = self.tls_session

        default_params = _ffdh_groups['modp2048'][0].parameter_numbers()
        default_mLen = _ffdh_groups['modp2048'][1]

        if not self.dh_p:
            self.dh_p = pkcs_i2osp(default_params.p, default_mLen//8)
        if self.dh_plen is None:
            self.dh_plen = len(self.dh_p)

        if not self.dh_g:
            self.dh_g = pkcs_i2osp(default_params.g, 1)
        if self.dh_glen is None:
            # a user-supplied generator may span more than one byte
            self.dh_glen = len(self.dh_g)

        p = pkcs_os2ip(self.dh_p)
        g = pkcs_os2ip(self.dh_g)
        real_params = dh.DHParameterNumbers(p, g).parameters(default_backend())

        if not self.dh_Ys:
            s.server_kx_privkey = real_params.generate_private_key()
            pubkey = s.server_kx_privkey.public_key()
            y = pubkey.public_numbers().y
            self.dh_Ys = pkcs_i2osp(y, pubkey.key_size//8)
        # else, we assume that the user wrote the server_kx_privkey by himself
        if self.dh_Yslen is None:
            self.dh_Yslen = len(self.dh_Ys)

        if not s.client_kx_ffdh_params:
            s.client_kx_ffdh_params = real_params

    @crypto_validator
    def register_pubkey(self):
        """
        Build the server public key from dh_p, dh_g and dh_Ys and store it
        as server_kx_pubkey in the TLS context. The group parameters are
        also stored as client_kx_ffdh_params when none were set yet.

        XXX Check that the pubkey received is in the group.
        """
        p = pkcs_os2ip(self.dh_p)
        g = pkcs_os2ip(self.dh_g)
        pn = dh.DHParameterNumbers(p, g)

        y = pkcs_os2ip(self.dh_Ys)
        public_numbers = dh.DHPublicNumbers(y, pn)

        s = self.tls_session
        s.server_kx_pubkey = public_numbers.public_key(default_backend())

        if not s.client_kx_ffdh_params:
            s.client_kx_ffdh_params = pn.parameters(default_backend())

    def post_dissection(self, r):
        """Register the dissected public key; ImportError is ignored."""
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        """
        The signature after the params gets saved as Padding.
        This way, the .getfield() which _TLSServerParamsField inherits
        from PacketField will return the signature remain as expected.
        """
        return Padding
367
368
369### Elliptic Curve Diffie-Hellman
370
# ECCurveType values, as used by the 'curve_type' byte of the ServerECDH*Params
# classes. The mapping goes from the numeric value to its name, so it cannot
# be indexed with a name.
_tls_ec_curve_types = { 1: "explicit_prime",
                        2: "explicit_char2",
                        3: "named_curve" }

# Basis type values for explicit char2 curves (numeric value -> name).
_tls_ec_basis_types = { 0: "ec_basis_trinomial", 1: "ec_basis_pentanomial"}
376
class ECCurvePkt(Packet):
    """
    The 'a' and 'b' coefficients of an explicit elliptic curve, each one
    preceded by its 1-byte length.
    """
    name = "Elliptic Curve"
    fields_desc = [ FieldLenField("alen", None, length_of="a", fmt="B"),
                    StrLenField("a", "", length_from = lambda pkt: pkt.alen),
                    FieldLenField("blen", None, length_of="b", fmt="B"),
                    StrLenField("b", "", length_from = lambda pkt: pkt.blen) ]

    def guess_payload_class(self, p):
        """
        Keep the bytes following the curve as Padding. This packet is used
        through PacketField, like the EC*Basis packets, and the enclosing
        packet gets its remaining bytes back from that Padding layer.
        """
        return Padding
383
384
385## Char2 Curves
386
class ECTrinomialBasis(Packet):
    """Trinomial basis of a char2 curve: 'k' preceded by its 1-byte length."""
    name = "EC Trinomial Basis"
    # basis type value, read back by _ECBasisField.i2basis_type()
    val = 0
    fields_desc = [
        FieldLenField("klen", None, length_of="k", fmt="B"),
        StrLenField("k", "", length_from=lambda p: p.klen),
    ]

    def guess_payload_class(self, p):
        """Whatever follows the basis is kept as Padding."""
        return Padding
394
class ECPentanomialBasis(Packet):
    """
    Pentanomial basis of a char2 curve: 'k1', 'k2' and 'k3', each one
    preceded by its 1-byte length.
    """
    name = "EC Pentanomial Basis"
    # basis type value, read back by _ECBasisField.i2basis_type()
    val = 1
    fields_desc = [
        FieldLenField("k1len", None, length_of="k1", fmt="B"),
        StrLenField("k1", "", length_from=lambda p: p.k1len),
        FieldLenField("k2len", None, length_of="k2", fmt="B"),
        StrLenField("k2", "", length_from=lambda p: p.k2len),
        FieldLenField("k3len", None, length_of="k3", fmt="B"),
        StrLenField("k3", "", length_from=lambda p: p.k3len),
    ]

    def guess_payload_class(self, p):
        """Whatever follows the basis is kept as Padding."""
        return Padding
406
# Basis type value -> packet class used by _ECBasisField to dissect the basis.
_tls_ec_basis_cls = { 0: ECTrinomialBasis, 1: ECPentanomialBasis}
408
class _ECBasisTypeField(ByteEnumField):
    """
    1-byte basis type. When left to None, its value is derived at build time
    from the field named by 'basis_type_of'.
    """
    __slots__ = ["basis_type_of"]

    def __init__(self, name, default, enum, basis_type_of, remain=0):
        # 'remain' is accepted for signature compatibility but is not used
        self.basis_type_of = basis_type_of
        EnumField.__init__(self, name, default, enum, "B")

    def i2m(self, pkt, x):
        """Return 'x', or the type reported by the related basis field."""
        if x is None:
            fld, fval = pkt.getfield_and_val(self.basis_type_of)
            x = fld.i2basis_type(pkt, fval)
        return x
421
class _ECBasisField(PacketField):
    """
    Basis of a char2 curve. The packet class used for dissection is looked up
    in 'clsdict' with the basis type returned by 'basis_type_from(pkt)'.
    """
    __slots__ = ["clsdict", "basis_type_from"]

    def __init__(self, name, default, basis_type_from, clsdict, remain=0):
        self.clsdict = clsdict
        self.basis_type_from = basis_type_from
        PacketField.__init__(self, name, default, None, remain=remain)

    def m2i(self, pkt, m):
        """Dissect 'm' with the class matching the packet's basis type."""
        basis = self.basis_type_from(pkt)
        cls = self.clsdict[basis]
        return cls(m)

    def i2basis_type(self, pkt, x):
        """Return the 'val' attribute of basis 'x', or 0 if it has none."""
        # only a missing attribute triggers the default
        return getattr(x, "val", 0)
441
442
443## Distinct ECParameters
444##
445## To support the different ECParameters structures defined in Sect. 5.4 of
## RFC 4492, we define 3 separate classes for implementing the 3 associated
447## ServerECDHParams: ServerECDHNamedCurveParams, ServerECDHExplicitPrimeParams
448## and ServerECDHExplicitChar2Params (support for this one is only partial).
449## The most frequent encounter of the 3 is (by far) ServerECDHNamedCurveParams.
450
class ServerECDHExplicitPrimeParams(_GenericTLSSessionInheritance):
    """
    We provide parsing abilities for ExplicitPrimeParams, but there is no
    support from the cryptography library, hence no context operations.
    """
    name = "Server ECDH parameters - Explicit Prime"
    fields_desc = [ ByteEnumField("curve_type", 1, _tls_ec_curve_types),
                    FieldLenField("plen", None, length_of="p", fmt="B"),
                    StrLenField("p", "", length_from=lambda pkt: pkt.plen),
                    PacketField("curve", None, ECCurvePkt),
                    FieldLenField("baselen", None, length_of="base", fmt="B"),
                    StrLenField("base", "",
                                length_from=lambda pkt: pkt.baselen),
                    FieldLenField("orderlen", None,
                                  length_of="order", fmt="B"),
                    StrLenField("order", "",
                                length_from=lambda pkt: pkt.orderlen),
                    FieldLenField("cofactorlen", None,
                                  length_of="cofactor", fmt="B"),
                    StrLenField("cofactor", "",
                                length_from=lambda pkt: pkt.cofactorlen),
                    FieldLenField("pointlen", None,
                                  length_of="point", fmt="B"),
                    StrLenField("point", "",
                                length_from=lambda pkt: pkt.pointlen) ]

    def fill_missing(self):
        """
        Set 'curve_type' to explicit_prime when it was reset to None.
        No other field is filled in: the curve parameters, the cofactor and
        the point are left as provided by the user.
        """
        if self.curve_type is None:
            # _tls_ec_curve_types maps value -> name, so the numeric value of
            # "explicit_prime" has to be spelled out here.
            self.curve_type = 1

    def guess_payload_class(self, p):
        """Whatever follows the params (the signature) is kept as Padding."""
        return Padding
487
488
class ServerECDHExplicitChar2Params(_GenericTLSSessionInheritance):
    """
    We provide parsing abilities for Char2Params, but there is no
    support from the cryptography library, hence no context operations.
    """
    name = "Server ECDH parameters - Explicit Char2"
    fields_desc = [ ByteEnumField("curve_type", 2, _tls_ec_curve_types),
                    ShortField("m", None),
                    _ECBasisTypeField("basis_type", None,
                                      _tls_ec_basis_types, "basis"),
                    _ECBasisField("basis", ECTrinomialBasis(),
                                  lambda pkt: pkt.basis_type,
                                  _tls_ec_basis_cls),
                    PacketField("curve", ECCurvePkt(), ECCurvePkt),
                    FieldLenField("baselen", None, length_of="base", fmt="B"),
                    StrLenField("base", "",
                                length_from = lambda pkt: pkt.baselen),
                    ByteField("order", None),
                    ByteField("cofactor", None),
                    FieldLenField("pointlen", None,
                                  length_of="point", fmt="B"),
                    StrLenField("point", "",
                                length_from = lambda pkt: pkt.pointlen) ]

    def fill_missing(self):
        """
        Set 'curve_type' to explicit_char2 when it was reset to None.
        No other field is filled in.
        """
        if self.curve_type is None:
            # _tls_ec_curve_types maps value -> name, so the numeric value of
            # "explicit_char2" has to be spelled out here.
            self.curve_type = 2

    def guess_payload_class(self, p):
        """Whatever follows the params (the signature) is kept as Padding."""
        return Padding
519
520
class ServerECDHNamedCurveParams(_GenericTLSSessionInheritance):
    """
    ServerECDHParams relying on a named curve: the curve identifier followed
    by the server's public point, preceded by its 1-byte length.

    Either with .fill_missing() or .post_dissection(), the server_kx_privkey
    or server_kx_pubkey of the TLS context are updated.
    """
    name = "Server ECDH parameters - Named Curve"
    fields_desc = [ ByteEnumField("curve_type", 3, _tls_ec_curve_types),
                    ShortEnumField("named_curve", None, _tls_named_curves),
                    FieldLenField("pointlen", None,
                                  length_of="point", fmt="B"),
                    StrLenField("point", None,
                                length_from = lambda pkt: pkt.pointlen) ]

    @crypto_validator
    def fill_missing(self):
        """
        We do not want TLSServerKeyExchange.build() to overload and recompute
        things everytime it is called. This method can be called specifically
        to have things filled in a smart fashion.

        A missing named_curve defaults to secp256r1. A fresh server_kx_privkey
        is generated only when 'point' is missing, so that an already filled
        point keeps matching the private key stored in the TLS context.

        XXX We should account for the point_format (before 'point' filling).
        """
        s = self.tls_session

        if self.curve_type is None:
            # _tls_ec_curve_types maps value -> name, so the numeric value of
            # "named_curve" has to be spelled out here.
            self.curve_type = 3

        if self.named_curve is None:
            curve = ec.SECP256R1()
            # reverse lookup of the curve identifier from its name
            curve_id = 0
            for cid, name in six.iteritems(_tls_named_curves):
                if name == curve.name:
                    curve_id = cid
                    break
            self.named_curve = curve_id
        else:
            curve_name = _tls_named_curves.get(self.named_curve)
            curve_cls = ec._CURVE_TYPES.get(curve_name)
            if curve_cls is None:
                # unknown or unsupported curve: this fallback is arguable
                curve = ec.SECP256R1()
            else:
                curve = curve_cls()

        if self.point is None:
            s.server_kx_privkey = ec.generate_private_key(curve,
                                                          default_backend())
            pubkey = s.server_kx_privkey.public_key()
            self.point = pubkey.public_numbers().encode_point()
        # else, we assume that the user wrote the server_kx_privkey by himself
        if self.pointlen is None:
            self.pointlen = len(self.point)

        if not s.client_kx_ecdh_params:
            s.client_kx_ecdh_params = curve

    @crypto_validator
    def register_pubkey(self):
        """
        Build the server public key from named_curve and point, and store it
        as server_kx_pubkey in the TLS context. The curve is also stored as
        client_kx_ecdh_params when none was set yet.

        XXX Support compressed point format (first byte 0x02 or 0x03).
        XXX Check that the pubkey received is on the curve.
        """
        curve_name = _tls_named_curves[self.named_curve]
        curve = ec._CURVE_TYPES[curve_name]()
        import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
        pubnum = import_point(curve, self.point)
        s = self.tls_session
        s.server_kx_pubkey = pubnum.public_key(default_backend())

        if not s.client_kx_ecdh_params:
            s.client_kx_ecdh_params = curve

    def post_dissection(self, r):
        """Register the dissected public key; ImportError is ignored."""
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        """Whatever follows the params (the signature) is kept as Padding."""
        return Padding
607
608
# curve_type value (first byte of the ServerECDHParams) -> dissection class.
_tls_server_ecdh_cls = { 1: ServerECDHExplicitPrimeParams,
                         2: ServerECDHExplicitChar2Params,
                         3: ServerECDHNamedCurveParams }
612
613def _tls_server_ecdh_cls_guess(m):
614    if not m:
615        return None
616    curve_type = orb(m[0])
617    return _tls_server_ecdh_cls.get(curve_type, None)
618
619
620### RSA Encryption (export)
621
class ServerRSAParams(_GenericTLSSessionInheritance):
    """
    Defined for RSA_EXPORT kx : it enables servers to share RSA keys shorter
    than their principal {>512}-bit key, when it is not allowed for kx.

    This should not appear in standard RSA kx negotiation, as the key
    has already been advertised in the Certificate message.
    """
    name = "Server RSA_EXPORT parameters"
    fields_desc = [ FieldLenField("rsamodlen", None, length_of="rsamod"),
                    StrLenField("rsamod", "",
                                length_from = lambda pkt: pkt.rsamodlen),
                    FieldLenField("rsaexplen", None, length_of="rsaexp"),
                    StrLenField("rsaexp", "",
                                length_from = lambda pkt: pkt.rsaexplen) ]

    @crypto_validator
    def fill_missing(self):
        """
        Generate a temporary 512-bit RSA key, store it as server_tmp_rsa_key
        in the TLS context, and use its public numbers for any missing
        rsamod/rsaexp. Lengths left to None are set from the values.
        """
        k = PrivKeyRSA()
        k.fill_and_store(modulusLen=512)
        self.tls_session.server_tmp_rsa_key = k
        pubNum = k.pubkey.public_numbers()

        if not self.rsamod:
            self.rsamod = pkcs_i2osp(pubNum.n, k.pubkey.key_size//8)
        if self.rsamodlen is None:
            self.rsamodlen = len(self.rsamod)

        if not self.rsaexp:
            # exact integer byte count, without going through float logs
            rsaexplen = (pubNum.e.bit_length() + 7) // 8
            self.rsaexp = pkcs_i2osp(pubNum.e, rsaexplen)
        if self.rsaexplen is None:
            self.rsaexplen = len(self.rsaexp)

    @crypto_validator
    def register_pubkey(self):
        """
        Build a PubKeyRSA from (rsaexp, rsamod, rsamodlen) and store it as
        server_tmp_rsa_key in the TLS context.
        """
        mLen = self.rsamodlen
        m    = self.rsamod
        e    = self.rsaexp
        self.tls_session.server_tmp_rsa_key = PubKeyRSA((e, m, mLen))

    def post_dissection(self, pkt):
        """Register the dissected public key; ImportError is ignored."""
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        """Whatever follows the params (the signature) is kept as Padding."""
        return Padding
671
672
673### Pre-Shared Key
674
class ServerPSKParams(Packet):
    """
    XXX Only some parsing abilities are provided for ServerPSKParams; the
    context operations have not been implemented yet. See RFC 4279.
    The (EC)DHE_PSK key exchange is not covered: it should contain a
    Server*DHParams after 'psk_identity_hint'.
    """
    name = "Server PSK parameters"
    fields_desc = [
        FieldLenField("psk_identity_hint_len", None,
                      length_of="psk_identity_hint", fmt="!H"),
        StrLenField("psk_identity_hint", "",
                    length_from=lambda p: p.psk_identity_hint_len),
    ]

    def fill_missing(self):
        """Nothing to fill in for now."""

    def post_dissection(self, pkt):
        """No context update for now."""

    def guess_payload_class(self, p):
        """Whatever follows the params is kept as Padding."""
        return Padding
696
697
698###############################################################################
699### Client Key Exchange value                                               ###
700###############################################################################
701
702### FFDH/ECDH
703
class ClientDiffieHellmanPublic(_GenericTLSSessionInheritance):
    """
    Client FFDH public value, preceded by its 2-byte length.

    When the user provides a value for the dh_Yc attribute, he is assumed to
    set the pms and ms accordingly and to trigger the key derivation himself.

    XXX As specified in 7.4.7.2. of RFC 4346, we should distinguish the needs
    for implicit or explicit value depending on availability of DH parameters
    in *client* certificate. For now we can only do ephemeral/explicit DH.
    """
    name = "Client DH Public Value"
    fields_desc = [
        FieldLenField("dh_Yclen", None, length_of="dh_Yc"),
        StrLenField("dh_Yc", "", length_from=lambda p: p.dh_Yclen),
    ]

    @crypto_validator
    def fill_missing(self):
        """
        Generate a client private key from client_kx_ffdh_params, set dh_Yc
        to the matching public value and, when the server public key is
        known, compute the pre master secret and derive the session keys.
        """
        s = self.tls_session
        s.client_kx_privkey = s.client_kx_ffdh_params.generate_private_key()
        pubkey = s.client_kx_privkey.public_key()
        self.dh_Yc = pkcs_i2osp(pubkey.public_numbers().y,
                                pubkey.key_size//8)

        if s.client_kx_privkey and s.server_kx_pubkey:
            s.pre_master_secret = s.client_kx_privkey.exchange(
                s.server_kx_pubkey)
            s.compute_ms_and_derive_keys()

    def post_build(self, pkt, pay):
        """Fill dh_Yc/dh_Yclen if needed, then rebuild the raw message."""
        if not self.dh_Yc:
            try:
                self.fill_missing()
            except ImportError:
                pass
        if self.dh_Yclen is None:
            self.dh_Yclen = len(self.dh_Yc)
        length_prefix = pkcs_i2osp(self.dh_Yclen, 2)
        return length_prefix + self.dh_Yc + pay

    def post_dissection(self, m):
        """
        Update the client public key of the context first. Then, if the
        server private key generated during Server*DHParams building is
        available, compute the shared secret, derive the session keys and
        update the context.
        """
        s = self.tls_session

        # if there are kx params and keys, we assume the crypto library is ok
        ffdh_params = s.client_kx_ffdh_params
        if ffdh_params:
            numbers = dh.DHPublicNumbers(pkcs_os2ip(self.dh_Yc),
                                         ffdh_params.parameter_numbers())
            s.client_kx_pubkey = numbers.public_key(default_backend())

        if s.server_kx_privkey and s.client_kx_pubkey:
            s.pre_master_secret = s.server_kx_privkey.exchange(
                s.client_kx_pubkey)
            s.compute_ms_and_derive_keys()

    def guess_payload_class(self, p):
        """Whatever follows the public value is kept as Padding."""
        return Padding
764
class ClientECDiffieHellmanPublic(_GenericTLSSessionInheritance):
    """
    Client ECDH public value: an encoded point preceded by its length.

    Note that the length field is a single byte here, whereas the previous
    class (finite field DH) builds its length prefix on 2 bytes.
    """
    name = "Client ECDH Public Value"
    fields_desc = [ FieldLenField("ecdh_Yclen", None,
                                  length_of="ecdh_Yc", fmt="B"),
                    StrLenField("ecdh_Yc", "",
                                length_from=lambda pkt: pkt.ecdh_Yclen)]

    @crypto_validator
    def fill_missing(self):
        """
        Generate a client key pair from the session 'client_kx_ecdh_params'
        and set 'ecdh_Yc' to the matching point, encoded as a 0x04 byte
        followed by the fixed-length x and y coordinates.

        The private key is stored in the session as 'client_kx_privkey'.
        If the server public key is also known, the ECDH shared secret is
        stored as the pre master secret and the session keys are derived.
        """
        s = self.tls_session
        params = s.client_kx_ecdh_params
        s.client_kx_privkey = ec.generate_private_key(params,
                                                      default_backend())
        pubkey = s.client_kx_privkey.public_key()
        numbers = pubkey.public_numbers()

        # Each coordinate takes ceil(key_size / 8) bytes. A floor division
        # would be one byte short for curves whose size is not a multiple
        # of 8 (e.g. secp521r1), which results in a malformed point.
        coord_len = (params.key_size + 7) // 8
        self.ecdh_Yc = (b"\x04" +
                        pkcs_i2osp(numbers.x, coord_len) +
                        pkcs_i2osp(numbers.y, coord_len))

        if s.client_kx_privkey and s.server_kx_pubkey:
            pms = s.client_kx_privkey.exchange(ec.ECDH(), s.server_kx_pubkey)
            s.pre_master_secret = pms
            s.compute_ms_and_derive_keys()

    def post_build(self, pkt, pay):
        """
        Build the wire format of the client ECDH public value.

        When 'ecdh_Yc' is empty, fill_missing() is called to produce it; an
        ImportError raised by that call (presumably when the crypto library
        is unavailable) is ignored. When 'ecdh_Yclen' is None, it is set to
        the length of 'ecdh_Yc'. The 'pkt' argument is not used: the result
        is rebuilt from the fields and followed by the payload 'pay'.
        """
        if not self.ecdh_Yc:
            try:
                self.fill_missing()
            except ImportError:
                # best effort: keep going with the (empty) public value
                pass
        if self.ecdh_Yclen is None:
            self.ecdh_Yclen = len(self.ecdh_Yc)
        return pkcs_i2osp(self.ecdh_Yclen, 1) + self.ecdh_Yc + pay

    def post_dissection(self, m):
        """
        Update the TLS session from a dissected client ECDH public value.

        If the session holds client ECDH parameters, the client public key
        is imported from the encoded point 'ecdh_Yc' and stored as
        'client_kx_pubkey'. Then, if both the server private key and the
        client public key are available, the ECDH shared secret is stored
        as the pre master secret and the session keys are derived.
        """
        s = self.tls_session

        # if there are kx params and keys, we assume the crypto library is ok
        if s.client_kx_ecdh_params:
            import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
            pub_num = import_point(s.client_kx_ecdh_params, self.ecdh_Yc)
            s.client_kx_pubkey = pub_num.public_key(default_backend())

        if s.server_kx_privkey and s.client_kx_pubkey:
            ZZ = s.server_kx_privkey.exchange(ec.ECDH(), s.client_kx_pubkey)
            s.pre_master_secret = ZZ
            s.compute_ms_and_derive_keys()
816
817
818### RSA Encryption (standard & export)
819
class _UnEncryptedPreMasterSecret(Raw):
    """
    Stand-in for an EncryptedPreMasterSecret whose content could not be
    deciphered: the encrypted data is simply kept as raw content.
    """
    name = "RSA Encrypted PreMaster Secret (protected)"

    def __init__(self, *args, **kargs):
        # the 'tls_session' argument is dropped before calling the parent
        # constructor
        kargs.pop('tls_session', None)
        super(_UnEncryptedPreMasterSecret, self).__init__(*args, **kargs)
830
class EncryptedPreMasterSecret(_GenericTLSSessionInheritance):
    """
    RSA-encrypted pre master secret (client version + 46 random bytes).
    Pay attention to implementation notes in section 7.4.7.1 of RFC 5246.
    """
    name = "RSA Encrypted PreMaster Secret"
    fields_desc = [ _TLSClientVersionField("client_version", None,
                                           _tls_version),
                    StrFixedLenField("random", None, 46) ]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        """
        Select the class used for dissection: when a 'tls_session' is given
        and it holds neither a temporary nor a regular server RSA key, the
        data cannot be deciphered and _UnEncryptedPreMasterSecret is used.
        """
        if 'tls_session' not in kargs:
            return EncryptedPreMasterSecret

        session = kargs['tls_session']
        no_rsa_key = (session.server_tmp_rsa_key is None and
                      session.server_rsa_key is None)
        if no_rsa_key:
            return _UnEncryptedPreMasterSecret
        return EncryptedPreMasterSecret

    def pre_dissect(self, m):
        """
        Decrypt the pre master secret before the fields get dissected.

        For TLS 1.0+ the 2-byte length prefix is removed when it matches the
        size of the remaining data; otherwise a warning is issued and the
        whole input is decrypted. The last 48 bytes of the decrypted data
        (or 48 null bytes if no server RSA key is available) are stored as
        the session pre master secret, the session keys are derived, and
        these 48 bytes are returned for dissection.
        """
        session = self.tls_session
        ciphertext = m
        if session.tls_version >= 0x0301:
            if len(m) < 2:      # Should not happen
                return m
            declared_len = struct.unpack("!H", m[:2])[0]
            if len(m) == declared_len + 2:
                ciphertext = m[2:]
            else:
                err = "TLS 1.0+, but RSA Encrypted PMS with no explicit length"
                warning(err)

        # priority is given to the tmp_key, if there is one
        rsa_key = session.server_tmp_rsa_key
        if rsa_key is None:
            rsa_key = session.server_rsa_key

        if rsa_key is not None:
            pms = rsa_key.decrypt(ciphertext)[-48:]
        else:
            # the dispatch_hook is supposed to prevent this case
            pms = b"\x00"*48
            err = "No server RSA key to decrypt Pre Master Secret. Skipping."
            warning(err)

        session.pre_master_secret = pms
        session.compute_ms_and_derive_keys()

        return pms

    def post_build(self, pkt, pay):
        """
        The built premaster secret (the 48 bytes) is first stored in the
        session, from which the keys are derived. It is then encrypted with
        either the temporary RSA key provided in a server key exchange
        message or the first server certificate; when neither is available,
        a warning is issued and the data is left as is. Finally, for
        TLS 1.0+, the 2 bytes providing the length are prepended, as
        described in implementation notes at the end of section 7.4.7.1.
        """
        session = self.tls_session
        session.pre_master_secret = pkt
        session.compute_ms_and_derive_keys()

        certs = session.server_certs
        if session.server_tmp_rsa_key is not None:
            body = session.server_tmp_rsa_key.encrypt(pkt, t="pkcs")
        elif certs is not None and len(certs) > 0:
            body = certs[0].encrypt(pkt, t="pkcs")
        else:
            warning("No material to encrypt Pre Master Secret")
            body = pkt

        if session.tls_version >= 0x0301:
            length_prefix = struct.pack("!H", len(body))
        else:
            length_prefix = b""
        return length_prefix + body + pay

    def guess_payload_class(self, p):
        """
        Always return Padding as the payload class.
        """
        return Padding
905
906
907# Pre-Shared Key
908
class ClientPSKIdentity(Packet):
    """
    PSK identity sent by the client: a 2-byte length followed by the
    identity itself.

    XXX We provide parsing abilities for ClientPSKIdentity, but the context
    operations have not been implemented yet. See RFC 4279.
    Note that we do not cover the (EC)DHE_PSK nor the RSA_PSK key exchange,
    which should contain either an EncryptedPMS or a ClientDiffieHellmanPublic.
    """
    # This used to be labelled "Server PSK parameters", which was misleading
    # for a structure holding the client identity.
    name = "Client PSK identity"
    fields_desc = [ FieldLenField("psk_identity_len", None,
                                  length_of="psk_identity", fmt="!H"),
                    StrLenField("psk_identity", "",
                        length_from=lambda pkt: pkt.psk_identity_len) ]
921
922