1## This file is part of Scapy
2## Copyright (C) 2017 Maxence Tury
3## This program is published under a GPLv2 license
4
5"""
6TLS 1.3 key exchange logic.
7"""
8
9import math
10
11from scapy.config import conf, crypto_validator
12from scapy.error import log_runtime, warning
13from scapy.fields import *
14from scapy.packet import Packet, Raw, Padding
15from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
16from scapy.layers.tls.session import _GenericTLSSessionInheritance
17from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
18from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
19from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
20from scapy.layers.tls.crypto.groups import (_tls_named_ffdh_groups,
21                                            _tls_named_curves, _ffdh_groups,
22                                            _tls_named_groups)
23
24if conf.crypto_valid:
25    from cryptography.hazmat.backends import default_backend
26    from cryptography.hazmat.primitives.asymmetric import dh, ec
27if conf.crypto_valid_advanced:
28    from cryptography.hazmat.primitives.asymmetric import x25519
29
30
class KeyShareEntry(Packet):
    """
    One TLS 1.3 KeyShareEntry: a named group and the public value offered
    for it.

    When building from scratch (empty key_exchange), a private key is
    generated for the group, kept in self.privkey, and its public value is
    written into key_exchange. When dissecting, the received public value
    is imported into self.pubkey. Default group is secp256r1 (23).

    Wire encoding of key_exchange:
      - FFDH groups: Y as a big-endian integer, left-padded with zeros to
        the byte length of the group prime p (RFC 8446, section 4.2.8.1);
      - x25519: the bytes returned by X25519PublicKey.public_bytes();
      - other supported curves: the bytes returned by encode_point().
    x448 is not handled.
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   StrLenField("key_exchange", b"",
                               length_from=lambda pkt: pkt.kxlen)]

    def __init__(self, *args, **kargs):
        # Initialized before Packet.__init__ so that they exist during any
        # dissection performed there (post_dissection sets self.pubkey).
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        Build with self.explicit temporarily forced to True.

        We need this hack, else 'self' would be replaced by __iter__.next().
        The previous value of self.explicit is restored even if building
        raises.
        """
        tmp = self.explicit
        self.explicit = True
        try:
            return super(KeyShareEntry, self).do_build()
        finally:
            self.explicit = tmp

    @crypto_validator
    def create_privkey(self):
        """
        Generate a private key for self.group and fill key_exchange.

        This is called by post_build() for key creation. Sets self.privkey
        and self.key_exchange. Nothing is generated for x448, for x25519
        when conf.crypto_valid_advanced is false, or for unknown groups.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            privkey = params.generate_private_key()
            self.privkey = privkey
            pubkey = privkey.public_key()
            # key_exchange is a byte-string field: encode Y big-endian and
            # left-pad it to the size of p instead of storing a Python int.
            p_len = (params.parameter_numbers().p.bit_length() + 7) // 8
            self.key_exchange = pkcs_i2osp(pubkey.public_numbers().y, p_len)
        elif self.group in _tls_named_curves:
            if _tls_named_curves[self.group] == "x25519":
                if conf.crypto_valid_advanced:
                    privkey = x25519.X25519PrivateKey.generate()
                    self.privkey = privkey
                    pubkey = privkey.public_key()
                    self.key_exchange = pubkey.public_bytes()
            elif _tls_named_curves[self.group] != "x448":
                curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
                privkey = ec.generate_private_key(curve, default_backend())
                self.privkey = privkey
                pubkey = privkey.public_key()
                self.key_exchange = pubkey.public_numbers().encode_point()

    def post_build(self, pkt, pay):
        """
        Fill in missing group, key_exchange and kxlen, then re-serialize.

        The group defaults to secp256r1 (23). An empty key_exchange triggers
        key generation; if the crypto library is unavailable (ImportError),
        key_exchange is left as is.
        """
        if self.group is None:
            self.group = 23     # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """
        Import the received key_exchange value into self.pubkey.

        Does nothing for x448, for x25519 when conf.crypto_valid_advanced
        is false, or for unknown groups.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            pn = params.parameter_numbers()
            # The wire value is a big-endian byte string; DHPublicNumbers
            # expects the integer Y.
            y = pkcs_os2ip(self.key_exchange)
            public_numbers = dh.DHPublicNumbers(y, pn)
            self.pubkey = public_numbers.public_key(default_backend())
        elif self.group in _tls_named_curves:
            if _tls_named_curves[self.group] == "x25519":
                if conf.crypto_valid_advanced:
                    import_point = x25519.X25519PublicKey.from_public_bytes
                    self.pubkey = import_point(self.key_exchange)
            elif _tls_named_curves[self.group] != "x448":
                curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
                import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
                public_numbers = import_point(curve, self.key_exchange)
                self.pubkey = public_numbers.public_key(default_backend())

    def post_dissection(self, r):
        # Best effort: without the crypto library, pubkey stays None.
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def extract_padding(self, s):
        # An entry has no payload: everything after it is handed back.
        return b"", s
126
127
class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    Key Share extension as carried in a ClientHello: a list of
    KeyShareEntry, one per offered group.

    Unless the session is frozen, building records each entry's private key
    in tls_session.tls13_client_privshares, and dissecting records each
    entry's public key in tls_session.tls13_client_pubshares. Both are keyed
    by the group name from _tls_named_groups, which is the key
    TLS_Ext_KeyShare_SH uses to find them. If a group name is already
    present, an info message is logged and the remaining entries are not
    recorded.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                            length_from=lambda pkt: pkt.client_shares_len)]

    def post_build(self, pkt, pay):
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if kse.privkey:
                    # _tls_named_groups covers FFDH groups too (the curve
                    # table raised KeyError for them). The same name is used
                    # for the duplicate check and for storage.
                    group_name = _tls_named_groups[kse.group]
                    if group_name in privshares:
                        pkt_info = pkt.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if kse.pubkey:
                    # Same naming as post_build, so that FFDH groups work and
                    # the server side finds the share by its group name.
                    group_name = _tls_named_groups[kse.group]
                    if group_name in pubshares:
                        pkt_info = r.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)
160
161
class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """
    Key Share extension as carried in a HelloRetryRequest.

    Unlike the ClientHello/ServerHello variants it carries no key share,
    only the selected_group field.
    """
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [
        ShortEnumField("type", 0x28, _tls_ext),
        ShortField("len", None),
        ShortEnumField("selected_group", None, _tls_named_groups),
    ]
167
168
class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    Key Share extension as carried in a ServerHello: a single KeyShareEntry.

    Unless the session is frozen:
      - building stores the server private key in
        tls_session.tls13_server_privshare;
      - dissecting stores the server public key in
        tls_session.tls13_server_pubshare.
    Both are keyed by the group name from _tls_named_groups. When the
    matching client share for that group is known, the shared secret is
    computed and stored in tls_session.tls13_dhe_secret.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry)]

    @staticmethod
    def _compute_dhe_secret(group_name, privkey, pubkey):
        """
        Return the shared secret of privkey and the peer pubkey.

        FFDH and x25519 private keys take the peer public key directly;
        other curves also need the ec.ECDH() algorithm object. Returns None
        if group_name is neither an FFDH group nor a named curve.
        """
        if group_name in _tls_named_ffdh_groups.values():
            return privkey.exchange(pubkey)
        if group_name in _tls_named_curves.values():
            if group_name == "x25519":
                return privkey.exchange(pubkey)
            return privkey.exchange(ec.ECDH(), pubkey)
        return None

    def post_build(self, pkt, pay):
        # server_share may be left at its None default: do not crash on it.
        privkey = getattr(self.server_share, "privkey", None)
        if not self.tls_session.frozen and privkey:
            # if there is a privkey, we assume the crypto library is ok
            privshare = self.tls_session.tls13_server_privshare
            if privshare:
                pkt_info = pkt.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[self.server_share.group]
            privshare[group_name] = privkey

            client_pubshares = self.tls_session.tls13_client_pubshares
            if group_name in client_pubshares:
                pms = self._compute_dhe_secret(group_name, privkey,
                                               client_pubshares[group_name])
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        # Same None-tolerance as post_build.
        pubkey = getattr(self.server_share, "pubkey", None)
        if not self.tls_session.frozen and pubkey:
            # if there is a pubkey, we assume the crypto library is ok
            pubshare = self.tls_session.tls13_server_pubshare
            if pubshare:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[self.server_share.group]
            pubshare[group_name] = pubkey

            client_privshares = self.tls_session.tls13_client_privshares
            if group_name in client_privshares:
                pms = self._compute_dhe_secret(group_name,
                                               client_privshares[group_name],
                                               pubkey)
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
220
221
# Key Share extension class to use for each enclosing message. The keys
# presumably are TLS handshake message types (1 = ClientHello,
# 2 = ServerHello, 6 = HelloRetryRequest); confirm against the code that
# consumes this table.
_tls_ext_keyshare_cls = {
    1: TLS_Ext_KeyShare_CH,
    2: TLS_Ext_KeyShare_SH,
    6: TLS_Ext_KeyShare_HRR,
}
225
226
class Ticket(Packet):
    """
    Session ticket laid out as recommended by RFC 5077.

    The layout is a 16-byte key_name, a 16-byte iv, the encrypted state
    (prefixed by its length, encstatelen) and a 32-byte mac.
    """
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [
        StrFixedLenField("key_name", None, 16),
        StrFixedLenField("iv", None, 16),
        FieldLenField("encstatelen", None, length_of="encstate"),
        StrLenField("encstate", "", length_from=lambda pkt: pkt.encstatelen),
        StrFixedLenField("mac", None, 32),
    ]
235
class TicketField(PacketField):
    """
    PacketField holding a Ticket whose size is given by length_from.

    length_from is a callable that receives the packet being dissected and
    returns the number of bytes belonging to the ticket. If it is None (the
    default), the whole remaining buffer is taken as the ticket. Bytes past
    the ticket are appended as a Padding layer.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, **kargs):
        self.length_from = length_from
        PacketField.__init__(self, name, default, Ticket, **kargs)

    def m2i(self, pkt, m):
        # Previously a None length_from (the advertised default) made this
        # raise TypeError; fall back to consuming the whole buffer.
        if self.length_from is None:
            ticket_len = len(m)
        else:
            ticket_len = self.length_from(pkt)
        tbd, rem = m[:ticket_len], m[ticket_len:]
        return self.cls(tbd)/Padding(rem)
246
class PSKIdentity(Packet):
    """
    One PSK identity: a length-prefixed Ticket (identity_len gives its
    size) followed by the 32-bit obfuscated_ticket_age.
    """
    name = "PSK Identity"
    fields_desc = [
        FieldLenField("identity_len", None, length_of="identity"),
        TicketField("identity", "",
                    length_from=lambda pkt: pkt.identity_len),
        IntField("obfuscated_ticket_age", 0),
    ]
254
class PSKBinderEntry(Packet):
    """
    One PSK binder: a single-byte length (binder_len) followed by the
    binder bytes.
    """
    name = "PSK Binder Entry"
    fields_desc = [
        FieldLenField("binder_len", None, fmt="B", length_of="binder"),
        StrLenField("binder", "", length_from=lambda pkt: pkt.binder_len),
    ]
261
class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    """
    Pre Shared Key extension as carried in a ClientHello.

    It holds a length-prefixed list of PSKIdentity followed by a
    length-prefixed list of PSKBinderEntry.
    """
    # XXX define post_build and post_dissection methods
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    # pre_shared_key is extension type 0x29, as in TLS_Ext_PreSharedKey_SH.
    # 0x28 (the previous default) is the key_share type used by the
    # TLS_Ext_KeyShare_* classes.
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("identities_len", None,
                                 length_of="identities"),
                   PacketListField("identities", [], PSKIdentity,
                            length_from=lambda pkt: pkt.identities_len),
                   FieldLenField("binders_len", None,
                                 length_of="binders"),
                   PacketListField("binders", [], PSKBinderEntry,
                            length_from=lambda pkt: pkt.binders_len)]
275
276
class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    """
    Pre Shared Key extension as carried in a ServerHello.

    It holds only the 16-bit selected_identity value.
    """
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        ShortField("selected_identity", None),
    ]
282
283
# Pre Shared Key extension class to use for each enclosing message. The keys
# presumably are TLS handshake message types (1 = ClientHello,
# 2 = ServerHello); confirm against the code that consumes this table.
_tls_ext_presharedkey_cls = {
    1: TLS_Ext_PreSharedKey_CH,
    2: TLS_Ext_PreSharedKey_SH,
}
286
287