1## This file is part of Scapy
2## Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
3##               2015, 2016, 2017 Maxence Tury
4## This program is published under a GPLv2 license
5
6"""
7The _TLSAutomaton class provides methods common to both TLS client and server.
8"""
9
10import struct
11
12from scapy.automaton import Automaton
13from scapy.error import log_interactive
14from scapy.packet import Raw
15from scapy.layers.tls.basefields import _tls_type
16from scapy.layers.tls.cert import Cert, PrivKey
17from scapy.layers.tls.record import TLS
18from scapy.layers.tls.record_sslv2 import SSLv2
19from scapy.layers.tls.record_tls13 import TLS13
20
21
class _TLSAutomaton(Automaton):
    """
    SSLv3 and TLS 1.0-1.2 typically need a 2-RTT handshake:

    Client        Server
      | --------->>> |    C1 - ClientHello
      | <<<--------- |    S1 - ServerHello
      | <<<--------- |    S1 - Certificate
      | <<<--------- |    S1 - ServerKeyExchange
      | <<<--------- |    S1 - ServerHelloDone
      | --------->>> |    C2 - ClientKeyExchange
      | --------->>> |    C2 - ChangeCipherSpec
      | --------->>> |    C2 - Finished [encrypted]
      | <<<--------- |    S2 - ChangeCipherSpec
      | <<<--------- |    S2 - Finished [encrypted]

    We call these successive groups of messages:
    ClientFlight1, ServerFlight1, ClientFlight2 and ServerFlight2.

    We want to send our messages from the same flight all at once through the
    socket. This is achieved by managing a list of records in 'buffer_out'.
    We may put several messages (i.e. what RFC 5246 calls the record fragments)
    in the same record when possible, but we may need several records for the
    same flight, as with ClientFlight2.

    However, note that the flights from the opposite side may be spread wildly
    across TLS records and TCP packets. This is why we use a 'get_next_msg'
    method for feeding a list of received messages, 'buffer_in'. Raw data
    which has not yet been interpreted as a TLS record is kept in 'remain_in'.
    """

    def parse_args(self, mycert=None, mykey=None, **kargs):
        """
        Initialize the state shared by TLS clients and servers.

        'mycert' and 'mykey', when given, are loaded with Cert and PrivKey
        respectively; otherwise the corresponding attribute is None.
        Every other keyword argument is forwarded to the parent parse_args;
        'verbose' (default True) is also read here to drive vprint().
        """
        super(_TLSAutomaton, self).parse_args(**kargs)

        self.socket = None
        self.remain_in = b""        # received bytes not dissected yet
        self.buffer_in = []         # these are 'fragments' inside records
        self.buffer_out = []        # these are records

        self.cur_session = None
        self.cur_pkt = None         # this is usually the latest parsed packet

        self.mycert = Cert(mycert) if mycert else None
        self.mykey = PrivKey(mykey) if mykey else None

        self.verbose = kargs.get("verbose", True)

    def _get_record_len(self, data):
        """
        Return the total length (header included) of the record starting at
        'data', or None when 'data' is too short to tell.

        A record starting with a known TLS content type followed by \\x03 is
        taken as TLS, whose 5-byte header stores the length in bytes 3-4.
        Anything else is taken as SSLv2: with the high bit of the first byte
        set, the header is 2 bytes and holds a 15-bit length; otherwise it is
        3 bytes and the first 2 bytes hold a 14-bit length.
        Note that SSLv2 records with length 0x1{4-7}03 are misread as TLS.
        """
        if len(data) < 2:
            return None
        byte0 = struct.unpack("B", data[:1])[0]
        byte1 = struct.unpack("B", data[1:2])[0]
        if byte0 in _tls_type and byte1 == 3:
            if len(data) < 5:
                return None
            return struct.unpack('!H', data[3:5])[0] + 5
        if byte0 & 0x80:
            return 2 + ((byte0 & 0x7f) << 8) + byte1
        return 3 + ((byte0 & 0x3f) << 8) + byte1

    def _get_record_msgs(self, record):
        """
        Return the messages carried by the dissected 'record'.

        Once the session version is TLS 1.3 or above, TLS13 records keep
        their messages in 'inner.msg'. Any other record (and any record while
        the version is unknown or below 1.3) keeps them in 'msg'.
        """
        version = self.cur_session.tls_version
        if version is not None and version >= 0x0304 and \
           isinstance(record, TLS13):
            return record.inner.msg
        return record.msg

    def get_next_msg(self, socket_timeout=2, retry=2):
        """
        Make the next message(s) available in self.buffer_in.

        If the list is not empty, nothing is done. Otherwise, the function
        starts from the data left in self.remain_in by a previous call and
        reads from the socket (timeout: 'socket_timeout' seconds per read)
        until a whole record is available. The record is then dissected and
        the messages it carries ('fragments') are appended to self.buffer_in.

        The record length is learnt incrementally. We first read 2 bytes.
        If they look like a TLS header (a known content type followed by
        \\x03), we read up to the 5-byte TLS header, which holds the record
        length. Otherwise we assume an SSLv2 record, whose first bytes encode
        the length. Finally, we read the remainder of the record.

        Each failed read (an exception or an empty read) consumes one of the
        'retry' attempts. If they run out before a whole record is available,
        the function returns and the partial data stays in self.remain_in.
        """
        if self.buffer_in:
            # A message is already available.
            return

        self.socket.settimeout(socket_timeout)
        while retry:
            reclen = self._get_record_len(self.remain_in)
            if reclen is not None and len(self.remain_in) >= reclen:
                break
            if reclen is None:
                # Read only what is needed to learn the record length:
                # first the 2 leading bytes, then the rest of a TLS header.
                target = 2 if len(self.remain_in) < 2 else 5
            else:
                target = reclen
            try:
                tmp = self.socket.recv(target - len(self.remain_in))
            except Exception:
                self.vprint("Could not join host ! Retrying...")
                retry -= 1
                continue
            if tmp:
                self.remain_in += tmp
            else:
                retry -= 1

        reclen = self._get_record_len(self.remain_in)
        if reclen is None or len(self.remain_in) < reclen:
            # Remote peer is not willing to respond
            return

        p = TLS(self.remain_in, tls_session=self.cur_session)
        self.cur_session = p.tls_session
        self.remain_in = b""
        if isinstance(p, SSLv2) and not p.msg:
            p.msg = Raw("")
        self.buffer_in += self._get_record_msgs(p)

        # The dissection may have produced further layers: additional records,
        # or raw bytes (e.g. an incomplete trailing record) to keep for later.
        # 'p' advances on every iteration, so unknown layers cannot stall us.
        while p.payload:
            p = p.payload
            if isinstance(p, Raw):
                self.remain_in += p.load
            elif isinstance(p, (TLS, TLS13)):
                self.buffer_in += self._get_record_msgs(p)

    def raise_on_packet(self, pkt_cls, state, get_next_msg=True):
        """
        If the next message to be processed has type 'pkt_cls', consume it
        (storing it in self.cur_pkt) and raise 'state'. Otherwise return
        without consuming anything.

        When 'get_next_msg' is True, self.get_next_msg() is called first with
        its default parameters, so that a message may be fetched if none is
        waiting.
        """
        # Maybe we already parsed the expected packet, maybe not.
        if get_next_msg:
            self.get_next_msg()
        if not self.buffer_in or not isinstance(self.buffer_in[0], pkt_cls):
            return
        self.cur_pkt = self.buffer_in.pop(0)
        raise state()

    def add_record(self, is_sslv2=None, is_tls13=None):
        """
        Add a new, empty TLS, SSLv2 or TLS 1.3 record to the packets
        buffered out.

        The record is SSLv2 if 'is_sslv2' is true, TLS13 if 'is_tls13' is
        true, and TLS otherwise. When neither flag is given, they are derived
        from the session version (tls_version, else advertised_tls_version):
        0x0200/0x0002 selects SSLv2 and 0x0304 or above selects TLS 1.3.
        """
        if is_sslv2 is None and is_tls13 is None:
            v = (self.cur_session.tls_version or
                 self.cur_session.advertised_tls_version)
            if v in [0x0200, 0x0002]:
                is_sslv2 = True
            elif v is not None and v >= 0x0304:
                # The None check avoids a TypeError on Python 3 when no
                # version is known yet; such sessions fall back to TLS.
                is_tls13 = True
        if is_sslv2:
            record_cls = SSLv2
        elif is_tls13:
            record_cls = TLS13
        else:
            record_cls = TLS
        self.buffer_out.append(record_cls(tls_session=self.cur_session))

    def add_msg(self, pkt):
        """
        Add a TLS message (e.g. TLSClientHello or TLSApplicationData)
        inside the latest record to be sent through the socket.

        If no record is buffered yet, one is created with add_record()'s
        defaults. A well-behaved automaton should rather call add_record()
        itself, so that the record type is chosen explicitly.
        """
        if not self.buffer_out:
            self.add_record()
        record = self.buffer_out[-1]
        # TLS13 records hold their messages in their 'inner' attribute.
        if isinstance(record, TLS13):
            record.inner.msg.append(pkt)
        else:
            record.msg.append(pkt)

    def flush_records(self):
        """
        Send all buffered records and update the session accordingly.

        The records are serialized with raw_stateful(), concatenated, and
        passed to a single self.socket.send() call. self.buffer_out is then
        emptied.
        """
        s = b"".join(p.raw_stateful() for p in self.buffer_out)
        # NOTE(review): a plain socket's send() may write only part of the
        # data and returns the byte count, which is ignored here. Confirm
        # whether self.socket guarantees full delivery (or use sendall()).
        self.socket.send(s)
        self.buffer_out = []

    def vprint(self, s=""):
        """Log 's', prefixed with "> ", through log_interactive if verbose."""
        if self.verbose:
            log_interactive.info("> %s", s)
226
227