1 /*
2  * Copyright (C) 2014 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.security.cts;
18 
19 import android.content.Context;
20 import android.test.InstrumentationTestCase;
21 import android.util.Log;
22 
23 import com.android.cts.security.R;
24 
25 import java.io.ByteArrayInputStream;
26 import java.io.ByteArrayOutputStream;
27 import java.io.EOFException;
28 import java.io.IOException;
29 import java.io.InputStream;
30 import java.io.OutputStream;
31 import java.net.ServerSocket;
32 import java.net.Socket;
33 import java.net.SocketAddress;
34 import java.security.KeyFactory;
35 import java.security.Principal;
36 import java.security.PrivateKey;
37 import java.security.cert.CertificateException;
38 import java.security.cert.CertificateFactory;
39 import java.security.cert.X509Certificate;
40 import java.security.spec.PKCS8EncodedKeySpec;
41 import java.util.concurrent.Callable;
42 import java.util.concurrent.ExecutionException;
43 import java.util.concurrent.ExecutorService;
44 import java.util.concurrent.Executors;
45 import java.util.concurrent.Future;
46 import java.util.concurrent.TimeUnit;
47 
48 import javax.net.ServerSocketFactory;
49 import javax.net.SocketFactory;
50 import javax.net.ssl.KeyManager;
51 import javax.net.ssl.SSLContext;
52 import javax.net.ssl.SSLException;
53 import javax.net.ssl.SSLServerSocket;
54 import javax.net.ssl.SSLSocket;
55 import javax.net.ssl.TrustManager;
56 import javax.net.ssl.X509KeyManager;
57 import javax.net.ssl.X509TrustManager;
58 
59 /**
60  * Tests for the OpenSSL Heartbleed vulnerability.
61  */
62 public class OpenSSLHeartbleedTest extends InstrumentationTestCase {
63 
64     // IMPLEMENTATION NOTE: This test spawns an SSLSocket client, SSLServerSocket server, and a
65     // Man-in-The-Middle (MiTM). The client connects to the MiTM which then connects to the server
66     // and starts forwarding all TLS records between the client and the server. In tests that check
67     // for the Heartbleed vulnerability, the MiTM also injects a HeartbeatRequest message into the
68     // traffic.
69 
// IMPLEMENTATION NOTE: This test spawns several background threads that perform network I/O
// on localhost. To ensure that these background threads are cleaned up at the end of the test,
// tearDown() closes the sockets they may be using. To aid this behavior, all Socket and
// ServerSocket instances are available as fields of this class. These fields should be accessed
// via the synchronized setters and getters to avoid memory visibility issues due to concurrency.
75 
private static final String TAG = OpenSSLHeartbleedTest.class.getSimpleName();

// Sockets used by the server, the client and the MiTM. They are kept in fields so that
// tearDown() can close them; access them only through the synchronized accessors.
private SSLServerSocket mServerListeningSocket;
private SSLSocket mServerSocket;
private SSLSocket mClientSocket;
private ServerSocket mMitmListeningSocket;
private Socket mMitmServerSocket;
private Socket mMitmClientSocket;
// Runs the MiTM, server, client and MiTM S->C forwarding tasks; shut down in tearDown().
private ExecutorService mExecutorService;

// Set once the MiTM has written the malformed HeartbeatRequest record.
private boolean mHeartbeatRequestWasInjected;
// Set when the MiTM forwards any heartbeat record in a direction whose records it inspects.
// NOTE(review): the name misspells "Detected"; renaming requires updating its accessors too.
private boolean mHeartbeatResponseWasDetetected;
// Description of the first fatal alert observed by the MiTM, or -1 if none has been seen.
private int mFirstDetectedFatalAlertDescription = -1;
89 
90     @Override
tearDown()91     protected void tearDown() throws Exception {
92         Log.i(TAG, "Tearing down");
93         if (mExecutorService != null) {
94             mExecutorService.shutdownNow();
95         }
96         closeQuietly(getServerListeningSocket());
97         closeQuietly(getServerSocket());
98         closeQuietly(getClientSocket());
99         closeQuietly(getMitmListeningSocket());
100         closeQuietly(getMitmServerSocket());
101         closeQuietly(getMitmClientSocket());
102         super.tearDown();
103         Log.i(TAG, "Tear down completed");
104     }
105 
106     /**
107      * Tests that TLS handshake succeeds when the MiTM simply forwards all data without tampering
108      * with it. This is to catch issues unrelated to TLS heartbeats.
109      */
testWithoutHeartbeats()110     public void testWithoutHeartbeats() throws Exception {
111         handshake(false, false);
112     }
113 
114     /**
115      * Tests whether client sockets are vulnerable to Heartbleed.
116      */
testClientHeartbleed()117     public void testClientHeartbleed() throws Exception {
118         checkHeartbleed(true);
119     }
120 
121     /**
122      * Tests whether server sockets are vulnerable to Heartbleed.
123      */
testServerHeartbleed()124     public void testServerHeartbleed() throws Exception {
125         checkHeartbleed(false);
126     }
127 
128     /**
129      * Tests for Heartbleed.
130      *
131      * @param client {@code true} to test the client, {@code false} to test the server.
132      */
checkHeartbleed(boolean client)133     private void checkHeartbleed(boolean client) throws Exception {
134         // IMPLEMENTATION NOTE: The MiTM is forwarding all TLS records between the client and the
135         // server unmodified. Additionally, the MiTM transmits a malformed HeartbeatRequest to
136         // server (if "client" argument is false) right after client's ClientKeyExchange or to
137         // client (if "client" argument is true) right after server's ServerHello. The peer is
138         // expected to either ignore the HeartbeatRequest (if heartbeats are supported) or to abort
139         // the handshake with unexpected_message alert (if heartbeats are not supported).
140         try {
141             handshake(true, client);
142         } catch (ExecutionException e) {
143             assertFalse(
144                     "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server")
145                             + " mode",
146                     wasHeartbeatResponseDetected());
147             if (e.getCause() instanceof SSLException) {
148                 // TLS handshake or data exchange failed. Check whether the error was caused by
149                 // fatal alert unexpected_message
150                 int alertDescription = getFirstDetectedFatalAlertDescription();
151                 if (alertDescription == -1) {
152                     fail("Handshake failed without a fatal alert");
153                 }
154                 assertEquals(
155                         "First fatal alert description received from server",
156                         AlertMessage.DESCRIPTION_UNEXPECTED_MESSAGE,
157                         alertDescription);
158                 return;
159             } else {
160                 throw e;
161             }
162         }
163 
164         // TLS handshake succeeded
165         assertFalse(
166                 "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server")
167                         + " mode",
168                 wasHeartbeatResponseDetected());
169         assertTrue("HeartbeatRequest not injected", wasHeartbeatRequestInjected());
170     }
171 
172     /**
173      * Starts the client, server, and the MiTM. Makes the client and server perform a TLS handshake
174      * and exchange application-level data. The MiTM injects a HeartbeatRequest message if requested
175      * by {@code heartbeatRequestInjected}. The direction of the injected message is specified by
176      * {@code injectedIntoClient}.
177      */
handshake( final boolean heartbeatRequestInjected, final boolean injectedIntoClient)178     private void handshake(
179             final boolean heartbeatRequestInjected,
180             final boolean injectedIntoClient) throws Exception {
181         mExecutorService = Executors.newFixedThreadPool(4);
182         setServerListeningSocket(serverBind());
183         final SocketAddress serverAddress = getServerListeningSocket().getLocalSocketAddress();
184         Log.i(TAG, "Server bound to " + serverAddress);
185 
186         setMitmListeningSocket(mitmBind());
187         final SocketAddress mitmAddress = getMitmListeningSocket().getLocalSocketAddress();
188         Log.i(TAG, "MiTM bound to " + mitmAddress);
189 
190         // Start the MiTM daemon in the background
191         mExecutorService.submit(new Callable<Void>() {
192             @Override
193             public Void call() throws Exception {
194                 mitmAcceptAndForward(
195                         serverAddress,
196                         heartbeatRequestInjected,
197                         injectedIntoClient);
198                 return null;
199             }
200         });
201         // Start the server in the background
202         Future<Void> serverFuture = mExecutorService.submit(new Callable<Void>() {
203             @Override
204             public Void call() throws Exception {
205                 serverAcceptAndHandshake();
206                 return null;
207             }
208         });
209         // Start the client in the background
210         Future<Void> clientFuture = mExecutorService.submit(new Callable<Void>() {
211             @Override
212             public Void call() throws Exception {
213                 clientConnectAndHandshake(mitmAddress);
214                 return null;
215             }
216         });
217 
218         // Wait for both client and server to terminate, to ensure that we observe all the traffic
219         // exchanged between them. Throw an exception if one of them failed.
220         Log.i(TAG, "Waiting for client");
221         // Wait for the client, but don't yet throw an exception if it failed.
222         Exception clientException = null;
223         try {
224             clientFuture.get(10, TimeUnit.SECONDS);
225         } catch (Exception e) {
226             clientException = e;
227         }
228         Log.i(TAG, "Waiting for server");
229         // Wait for the server and throw an exception if it failed.
230         serverFuture.get(5, TimeUnit.SECONDS);
231         // Throw an exception if the client failed.
232         if (clientException != null) {
233             throw clientException;
234         }
235         Log.i(TAG, "Handshake completed and application data exchanged");
236     }
237 
clientConnectAndHandshake(SocketAddress serverAddress)238     private void clientConnectAndHandshake(SocketAddress serverAddress) throws Exception {
239         SSLContext sslContext = SSLContext.getInstance("TLS");
240         sslContext.init(
241                 null,
242                 new TrustManager[] {new TrustAllX509TrustManager()},
243                 null);
244         SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
245         setClientSocket(socket);
246         try {
247             Log.i(TAG, "Client connecting to " + serverAddress);
248             socket.connect(serverAddress);
249             Log.i(TAG, "Client connected to server from " + socket.getLocalSocketAddress());
250             // Ensure a TLS handshake is performed and an exception is thrown if it fails.
251             socket.getOutputStream().write("client".getBytes());
252             socket.getOutputStream().flush();
253             Log.i(TAG, "Client sent request. Reading response");
254             int b = socket.getInputStream().read();
255             Log.i(TAG, "Client read response: " + b);
256         } catch (Exception e) {
257             Log.w(TAG, "Client failed", e);
258             throw e;
259           } finally {
260             socket.close();
261         }
262     }
263 
serverBind()264     public SSLServerSocket serverBind() throws Exception {
265         // Load the server's private key and cert chain
266         KeyFactory keyFactory = KeyFactory.getInstance("RSA");
267         PrivateKey privateKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(
268                 readResource(
269                         getInstrumentation().getContext(), R.raw.openssl_heartbleed_test_key)));
270         CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
271         X509Certificate[] certChain =  new X509Certificate[] {
272                 (X509Certificate) certFactory.generateCertificate(
273                         new ByteArrayInputStream(readResource(
274                                 getInstrumentation().getContext(),
275                                 R.raw.openssl_heartbleed_test_cert)))
276         };
277 
278         // Initialize TLS context to use the private key and cert chain for server sockets
279         SSLContext sslContext = SSLContext.getInstance("TLS");
280         sslContext.init(
281                 new KeyManager[] {new HardcodedCertX509KeyManager(privateKey, certChain)},
282                 null,
283                 null);
284 
285         Log.i(TAG, "Server binding to local port");
286         return (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(0);
287     }
288 
serverAcceptAndHandshake()289     private void serverAcceptAndHandshake() throws Exception {
290         SSLSocket socket = null;
291         SSLServerSocket serverSocket = getServerListeningSocket();
292         try {
293             Log.i(TAG, "Server listening for incoming connection");
294             socket = (SSLSocket) serverSocket.accept();
295             setServerSocket(socket);
296             Log.i(TAG, "Server accepted connection from " + socket.getRemoteSocketAddress());
297             // Ensure a TLS handshake is performed and an exception is thrown if it fails.
298             socket.getOutputStream().write("server".getBytes());
299             socket.getOutputStream().flush();
300             Log.i(TAG, "Server sent reply. Reading response");
301             int b = socket.getInputStream().read();
302             Log.i(TAG, "Server read response: " + b);
303         } catch (Exception e) {
304           Log.w(TAG, "Server failed", e);
305           throw e;
306         } finally {
307             if (socket != null) {
308                 socket.close();
309             }
310         }
311     }
312 
mitmBind()313     private ServerSocket mitmBind() throws Exception {
314         Log.i(TAG, "MiTM binding to local port");
315         return ServerSocketFactory.getDefault().createServerSocket(0);
316     }
317 
318     /**
319      * Accepts the connection on the MiTM listening socket, forwards the TLS records between the
320      * client and the server, and, if requested, injects a {@code HeartbeatRequest}.
321      *
322      * @param injectHeartbeat whether to inject a {@code HeartbeatRequest} message.
323      * @param injectIntoClient when {@code injectHeartbeat} is {@code true}, whether to inject the
324      *        {@code HeartbeatRequest} message into client or into server.
325      */
mitmAcceptAndForward( SocketAddress serverAddress, final boolean injectHeartbeat, final boolean injectIntoClient)326     private void mitmAcceptAndForward(
327             SocketAddress serverAddress,
328             final boolean injectHeartbeat,
329             final boolean injectIntoClient) throws Exception {
330         Socket clientSocket = null;
331         Socket serverSocket = null;
332         ServerSocket listeningSocket = getMitmListeningSocket();
333         try {
334             Log.i(TAG, "MiTM waiting for incoming connection");
335             clientSocket = listeningSocket.accept();
336             setMitmClientSocket(clientSocket);
337             Log.i(TAG, "MiTM accepted connection from " + clientSocket.getRemoteSocketAddress());
338             serverSocket = SocketFactory.getDefault().createSocket();
339             setMitmServerSocket(serverSocket);
340             Log.i(TAG, "MiTM connecting to server " + serverAddress);
341             serverSocket.connect(serverAddress, 10000);
342             Log.i(TAG, "MiTM connected to server from " + serverSocket.getLocalSocketAddress());
343             final InputStream serverInputStream = serverSocket.getInputStream();
344             final OutputStream clientOutputStream = clientSocket.getOutputStream();
345             Future<Void> serverToClientTask = mExecutorService.submit(new Callable<Void>() {
346                 @Override
347                 public Void call() throws Exception {
348                     // Inject HeatbeatRequest after ServerHello, if requested
349                     forwardTlsRecords(
350                             "MiTM S->C",
351                             serverInputStream,
352                             clientOutputStream,
353                             (injectHeartbeat && injectIntoClient)
354                                     ? HandshakeMessage.TYPE_SERVER_HELLO : -1);
355                     return null;
356                 }
357             });
358             // Inject HeatbeatRequest after ClientKeyExchange, if requested
359             forwardTlsRecords(
360                     "MiTM C->S",
361                     clientSocket.getInputStream(),
362                     serverSocket.getOutputStream(),
363                     (injectHeartbeat && !injectIntoClient)
364                             ? HandshakeMessage.TYPE_CLIENT_KEY_EXCHANGE : -1);
365             serverToClientTask.get(10, TimeUnit.SECONDS);
366         } catch (Exception e) {
367             Log.w(TAG, "MiTM failed", e);
368             throw e;
369           } finally {
370             closeQuietly(clientSocket);
371             closeQuietly(serverSocket);
372         }
373     }
374 
375     /**
376      * Forwards TLS records from the provided {@code InputStream} to the provided
377      * {@code OutputStream}. If requested, injects a {@code HeartbeatMessage}.
378      */
forwardTlsRecords( String logPrefix, InputStream in, OutputStream out, int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest)379     private void forwardTlsRecords(
380             String logPrefix,
381             InputStream in,
382             OutputStream out,
383             int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws Exception {
384         Log.i(TAG, logPrefix + ": record forwarding started");
385         boolean interestingRecordsLogged =
386                 handshakeMessageTypeAfterWhichToInjectHeartbeatRequest == -1;
387         try {
388             TlsRecordReader reader = new TlsRecordReader(in);
389             byte[] recordBytes;
390             // Fragments contained in records may be encrypted after a certain point in the
391             // handshake. Once they are encrypted, this MiTM cannot inspect their plaintext which.
392             boolean fragmentEncryptionMayBeEnabled = false;
393             while ((recordBytes = reader.readRecord()) != null) {
394                 TlsRecord record = TlsRecord.parse(recordBytes);
395                 forwardTlsRecord(logPrefix,
396                         recordBytes,
397                         record,
398                         fragmentEncryptionMayBeEnabled,
399                         out,
400                         interestingRecordsLogged,
401                         handshakeMessageTypeAfterWhichToInjectHeartbeatRequest);
402                 if (record.protocol == TlsProtocols.CHANGE_CIPHER_SPEC) {
403                     fragmentEncryptionMayBeEnabled = true;
404                 }
405             }
406         } catch (Exception e) {
407             Log.w(TAG, logPrefix + ": failed", e);
408             throw e;
409         } finally {
410             Log.d(TAG, logPrefix + ": record forwarding finished");
411         }
412     }
413 
forwardTlsRecord( String logPrefix, byte[] recordBytes, TlsRecord record, boolean fragmentEncryptionMayBeEnabled, OutputStream out, boolean interestingRecordsLogged, int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest)414     private void forwardTlsRecord(
415             String logPrefix,
416             byte[] recordBytes,
417             TlsRecord record,
418             boolean fragmentEncryptionMayBeEnabled,
419             OutputStream out,
420             boolean interestingRecordsLogged,
421             int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws IOException {
422         // Save information about the records if its of interest to this test
423         if (interestingRecordsLogged) {
424             switch (record.protocol) {
425                 case TlsProtocols.ALERT:
426                     if (!fragmentEncryptionMayBeEnabled) {
427                         AlertMessage alert = AlertMessage.tryParse(record);
428                         if ((alert != null) && (alert.level == AlertMessage.LEVEL_FATAL)) {
429                             setFatalAlertDetected(alert.description);
430                         }
431                     }
432                     break;
433                 case TlsProtocols.HEARTBEAT:
434                     // When TLS records are encrypted, we cannot determine whether a
435                     // heartbeat is a HeartbeatResponse. In our setup, the client and the
436                     // server are not expected to sent HeartbeatRequests. Thus, we err on
437                     // the side of caution and assume that any heartbeat message sent by
438                     // client or server is a HeartbeatResponse.
439                     Log.e(TAG, logPrefix
440                             + ": heartbeat response detected -- vulnerable to Heartbleed");
441                     setHeartbeatResponseWasDetected();
442                     break;
443             }
444         }
445 
446         Log.i(TAG, logPrefix + ": Forwarding TLS record. "
447                 + getRecordInfo(record, fragmentEncryptionMayBeEnabled));
448         out.write(recordBytes);
449         out.flush();
450 
451         // Inject HeartbeatRequest, if necessary, after the specified handshake message type
452         if (handshakeMessageTypeAfterWhichToInjectHeartbeatRequest != -1) {
453             if ((!fragmentEncryptionMayBeEnabled) && (isHandshakeMessageType(
454                     record, handshakeMessageTypeAfterWhichToInjectHeartbeatRequest))) {
455                 // The Heartbeat Request message below is malformed because its declared
456                 // length of payload one byte larger than the actual payload. The peer is
457                 // supposed to reject such messages.
458                 byte[] payload = "arbitrary".getBytes("US-ASCII");
459                 byte[] heartbeatRequestRecordBytes = createHeartbeatRequestRecord(
460                         record.versionMajor,
461                         record.versionMinor,
462                         payload.length + 1,
463                         payload);
464                 Log.i(TAG, logPrefix + ": Injecting malformed HeartbeatRequest: "
465                         + getRecordInfo(
466                                 TlsRecord.parse(heartbeatRequestRecordBytes), false));
467                 setHeartbeatRequestWasInjected();
468                 out.write(heartbeatRequestRecordBytes);
469                 out.flush();
470             }
471         }
472     }
473 
getRecordInfo(TlsRecord record, boolean mayBeEncrypted)474     private static String getRecordInfo(TlsRecord record, boolean mayBeEncrypted) {
475         StringBuilder result = new StringBuilder();
476         result.append(getProtocolName(record.protocol))
477                 .append(", ")
478                 .append(getFragmentInfo(record, mayBeEncrypted));
479         return result.toString();
480     }
481 
getProtocolName(int protocol)482     private static String getProtocolName(int protocol) {
483         switch (protocol) {
484             case TlsProtocols.ALERT:
485                 return "alert";
486             case TlsProtocols.APPLICATION_DATA:
487                 return "application data";
488             case TlsProtocols.CHANGE_CIPHER_SPEC:
489                 return "change cipher spec";
490             case TlsProtocols.HANDSHAKE:
491                 return "handshake";
492             case TlsProtocols.HEARTBEAT:
493                 return "heatbeat";
494             default:
495                 return String.valueOf(protocol);
496         }
497     }
498 
getFragmentInfo(TlsRecord record, boolean mayBeEncrypted)499     private static String getFragmentInfo(TlsRecord record, boolean mayBeEncrypted) {
500         StringBuilder result = new StringBuilder();
501         if (mayBeEncrypted) {
502             result.append("encrypted?");
503         } else {
504             switch (record.protocol) {
505                 case TlsProtocols.ALERT:
506                     result.append("level: " + ((record.fragment.length > 0)
507                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a")
508                     + ", description: "
509                     + ((record.fragment.length > 1)
510                             ? String.valueOf(record.fragment[1] & 0xff) : "n/a"));
511                     break;
512                 case TlsProtocols.APPLICATION_DATA:
513                     break;
514                 case TlsProtocols.CHANGE_CIPHER_SPEC:
515                     result.append("payload: " + ((record.fragment.length > 0)
516                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a"));
517                     break;
518                 case TlsProtocols.HANDSHAKE:
519                     result.append("type: " + ((record.fragment.length > 0)
520                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a"));
521                     break;
522                 case TlsProtocols.HEARTBEAT:
523                     result.append("type: " + ((record.fragment.length > 0)
524                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a")
525                             + ", payload length: "
526                             + ((record.fragment.length >= 3)
527                                     ? String.valueOf(
528                                             getUnsignedShortBigEndian(record.fragment, 1))
529                                     : "n/a"));
530                     break;
531             }
532         }
533         result.append(", ").append("fragment length: " + record.fragment.length);
534         return result.toString();
535     }
536 
setServerListeningSocket(SSLServerSocket socket)537     private synchronized void setServerListeningSocket(SSLServerSocket socket) {
538         mServerListeningSocket = socket;
539     }
540 
getServerListeningSocket()541     private synchronized SSLServerSocket getServerListeningSocket() {
542         return mServerListeningSocket;
543     }
544 
setServerSocket(SSLSocket socket)545     private synchronized void setServerSocket(SSLSocket socket) {
546         mServerSocket = socket;
547     }
548 
getServerSocket()549     private synchronized SSLSocket getServerSocket() {
550         return mServerSocket;
551     }
552 
setClientSocket(SSLSocket socket)553     private synchronized void setClientSocket(SSLSocket socket) {
554         mClientSocket = socket;
555     }
556 
getClientSocket()557     private synchronized SSLSocket getClientSocket() {
558         return mClientSocket;
559     }
560 
setMitmListeningSocket(ServerSocket socket)561     private synchronized void setMitmListeningSocket(ServerSocket socket) {
562         mMitmListeningSocket = socket;
563     }
564 
getMitmListeningSocket()565     private synchronized ServerSocket getMitmListeningSocket() {
566         return mMitmListeningSocket;
567     }
568 
setMitmServerSocket(Socket socket)569     private synchronized void setMitmServerSocket(Socket socket) {
570         mMitmServerSocket = socket;
571     }
572 
getMitmServerSocket()573     private synchronized Socket getMitmServerSocket() {
574         return mMitmServerSocket;
575     }
576 
setMitmClientSocket(Socket socket)577     private synchronized void setMitmClientSocket(Socket socket) {
578         mMitmClientSocket = socket;
579     }
580 
getMitmClientSocket()581     private synchronized Socket getMitmClientSocket() {
582         return mMitmClientSocket;
583     }
584 
setHeartbeatRequestWasInjected()585     private synchronized void setHeartbeatRequestWasInjected() {
586         mHeartbeatRequestWasInjected = true;
587     }
588 
wasHeartbeatRequestInjected()589     private synchronized boolean wasHeartbeatRequestInjected() {
590         return mHeartbeatRequestWasInjected;
591     }
592 
setHeartbeatResponseWasDetected()593     private synchronized void setHeartbeatResponseWasDetected() {
594         mHeartbeatResponseWasDetetected = true;
595     }
596 
wasHeartbeatResponseDetected()597     private synchronized boolean wasHeartbeatResponseDetected() {
598         return mHeartbeatResponseWasDetetected;
599     }
600 
setFatalAlertDetected(int description)601     private synchronized void setFatalAlertDetected(int description) {
602         if (mFirstDetectedFatalAlertDescription == -1) {
603             mFirstDetectedFatalAlertDescription = description;
604         }
605     }
606 
getFirstDetectedFatalAlertDescription()607     private synchronized int getFirstDetectedFatalAlertDescription() {
608         return mFirstDetectedFatalAlertDescription;
609     }
610 
611     public static abstract class TlsProtocols {
612         public static final int CHANGE_CIPHER_SPEC = 20;
613         public static final int ALERT = 21;
614         public static final int HANDSHAKE = 22;
615         public static final int APPLICATION_DATA = 23;
616         public static final int HEARTBEAT = 24;
TlsProtocols()617         private TlsProtocols() {}
618     }
619 
620     public static class TlsRecord {
621         public int protocol;
622         public int versionMajor;
623         public int versionMinor;
624         public byte[] fragment;
625 
parse(byte[] record)626         public static TlsRecord parse(byte[] record) throws IOException {
627             TlsRecord result = new TlsRecord();
628             if (record.length < TlsRecordReader.RECORD_HEADER_LENGTH) {
629                 throw new IOException("Record too short: " + record.length);
630             }
631             result.protocol = record[0] & 0xff;
632             result.versionMajor = record[1] & 0xff;
633             result.versionMinor = record[2] & 0xff;
634             int fragmentLength = getUnsignedShortBigEndian(record, 3);
635             int actualFragmentLength = record.length - TlsRecordReader.RECORD_HEADER_LENGTH;
636             if (fragmentLength != actualFragmentLength) {
637                 throw new IOException("Fragment length mismatch. Expected: " + fragmentLength
638                         + ", actual: " + actualFragmentLength);
639             }
640             result.fragment = new byte[fragmentLength];
641             System.arraycopy(
642                     record, TlsRecordReader.RECORD_HEADER_LENGTH,
643                     result.fragment, 0,
644                     fragmentLength);
645             return result;
646         }
647 
unparse(TlsRecord record)648         public static byte[] unparse(TlsRecord record) {
649             byte[] result = new byte[TlsRecordReader.RECORD_HEADER_LENGTH + record.fragment.length];
650             result[0] = (byte) record.protocol;
651             result[1] = (byte) record.versionMajor;
652             result[2] = (byte) record.versionMinor;
653             putUnsignedShortBigEndian(result, 3, record.fragment.length);
654             System.arraycopy(
655                     record.fragment, 0,
656                     result, TlsRecordReader.RECORD_HEADER_LENGTH,
657                     record.fragment.length);
658             return result;
659         }
660     }
661 
isHandshakeMessageType(TlsRecord record, int type)662     public static final boolean isHandshakeMessageType(TlsRecord record, int type) {
663         HandshakeMessage handshake = HandshakeMessage.tryParse(record);
664         if (handshake == null) {
665             return false;
666         }
667         return handshake.type == type;
668     }
669 
670     public static class HandshakeMessage {
671         public static final int TYPE_SERVER_HELLO = 2;
672         public static final int TYPE_CERTIFICATE = 11;
673         public static final int TYPE_CLIENT_KEY_EXCHANGE = 16;
674 
675         public int type;
676 
677         /**
678          * Parses the provided TLS record as a handshake message.
679          *
680          * @return alert message or {@code null} if the record does not contain a handshake message.
681          */
tryParse(TlsRecord record)682         public static HandshakeMessage tryParse(TlsRecord record) {
683             if (record.protocol != TlsProtocols.HANDSHAKE) {
684                 return null;
685             }
686             if (record.fragment.length < 1) {
687                 return null;
688             }
689             HandshakeMessage result = new HandshakeMessage();
690             result.type = record.fragment[0] & 0xff;
691             return result;
692         }
693     }
694 
695     public static class AlertMessage {
696         public static final int LEVEL_FATAL = 2;
697         public static final int DESCRIPTION_UNEXPECTED_MESSAGE = 10;
698 
699         public int level;
700         public int description;
701 
702         /**
703          * Parses the provided TLS record as an alert message.
704          *
705          * @return alert message or {@code null} if the record does not contain an alert message.
706          */
tryParse(TlsRecord record)707         public static AlertMessage tryParse(TlsRecord record) {
708             if (record.protocol != TlsProtocols.ALERT) {
709                 return null;
710             }
711             if (record.fragment.length < 2) {
712                 return null;
713             }
714             AlertMessage result = new AlertMessage();
715             result.level = record.fragment[0] & 0xff;
716             result.description = record.fragment[1] & 0xff;
717             return result;
718         }
719     }
720 
721     private static abstract class HeartbeatProtocol {
HeartbeatProtocol()722         private HeartbeatProtocol() {}
723 
724         private static final int MESSAGE_TYPE_REQUEST = 1;
725         @SuppressWarnings("unused")
726         private static final int MESSAGE_TYPE_RESPONSE = 2;
727 
728         private static final int MESSAGE_HEADER_LENGTH = 3;
729         private static final int MESSAGE_PADDING_LENGTH = 16;
730     }
731 
createHeartbeatRequestRecord( int versionMajor, int versionMinor, int declaredPayloadLength, byte[] payload)732     private static byte[] createHeartbeatRequestRecord(
733             int versionMajor, int versionMinor,
734             int declaredPayloadLength, byte[] payload) {
735 
736         byte[] fragment = new byte[HeartbeatProtocol.MESSAGE_HEADER_LENGTH
737                 + payload.length + HeartbeatProtocol.MESSAGE_PADDING_LENGTH];
738         fragment[0] = HeartbeatProtocol.MESSAGE_TYPE_REQUEST;
739         putUnsignedShortBigEndian(fragment, 1, declaredPayloadLength); // payload_length
740         TlsRecord record = new TlsRecord();
741         record.protocol = TlsProtocols.HEARTBEAT;
742         record.versionMajor = versionMajor;
743         record.versionMinor = versionMinor;
744         record.fragment = fragment;
745         return TlsRecord.unparse(record);
746     }
747 
748     /**
749      * Reader of TLS records.
750      */
751     public static class TlsRecordReader {
752         private static final int MAX_RECORD_LENGTH = 16384;
753         public static final int RECORD_HEADER_LENGTH = 5;
754 
755         private final InputStream in;
756         private final byte[] buffer;
757         private int firstBufferedByteOffset;
758         private int bufferedByteCount;
759 
TlsRecordReader(InputStream in)760         public TlsRecordReader(InputStream in) {
761             this.in = in;
762             buffer = new byte[MAX_RECORD_LENGTH];
763         }
764 
765         /**
766          * Reads the next TLS record.
767          *
768          * @return TLS record or {@code null} if EOF was encountered before any bytes of a record
769          *         could be read.
770          */
readRecord()771         public byte[] readRecord() throws IOException {
772             // Ensure that a TLS record header (or more) is in the buffer.
773             if (bufferedByteCount < RECORD_HEADER_LENGTH) {
774                 boolean eofPermittedInstead = (bufferedByteCount == 0);
775                 boolean eofEncounteredInstead =
776                         !readAtLeast(RECORD_HEADER_LENGTH, eofPermittedInstead);
777                 if (eofEncounteredInstead) {
778                     // End of stream reached exactly before a TLS record start.
779                     return null;
780                 }
781             }
782 
783             // TLS record header (or more) is in the buffer.
784             // Ensure that the rest of the record is in the buffer.
785             int fragmentLength = getUnsignedShortBigEndian(buffer, firstBufferedByteOffset + 3);
786             int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
787             if (recordLength > MAX_RECORD_LENGTH) {
788                 throw new IOException("TLS record too long: " + recordLength);
789             }
790             if (bufferedByteCount < recordLength) {
791                 readAtLeast(recordLength - bufferedByteCount, false);
792             }
793 
794             // TLS record (or more) is in the buffer.
795             byte[] record = new byte[recordLength];
796             System.arraycopy(buffer, firstBufferedByteOffset, record, 0, recordLength);
797             firstBufferedByteOffset += recordLength;
798             bufferedByteCount -= recordLength;
799             return record;
800         }
801 
802         /**
803          * Reads at least the specified number of bytes from the underlying {@code InputStream} into
804          * the {@code buffer}.
805          *
806          * <p>Bytes buffered but not yet returned to the client in the {@code buffer} are relocated
807          * to the start of the buffer to make space if necessary.
808          *
809          * @param eofPermittedInstead {@code true} if it's permitted for an EOF to be encountered
810          *        without any bytes having been read.
811          *
812          * @return {@code true} if the requested number of bytes (or more) has been read,
813          *         {@code false} if {@code eofPermittedInstead} was {@code true} and EOF was
814          *         encountered when no bytes have yet been read.
815          */
readAtLeast(int size, boolean eofPermittedInstead)816         private boolean readAtLeast(int size, boolean eofPermittedInstead) throws IOException {
817             ensureRemainingBufferCapacityAtLeast(size);
818             boolean firstAttempt = true;
819             while (size > 0) {
820                 int chunkSize = in.read(
821                         buffer,
822                         firstBufferedByteOffset + bufferedByteCount,
823                         buffer.length - (firstBufferedByteOffset + bufferedByteCount));
824                 if (chunkSize == -1) {
825                     if ((firstAttempt) && (eofPermittedInstead)) {
826                         return false;
827                     } else {
828                         throw new EOFException("Premature EOF");
829                     }
830                 }
831                 firstAttempt = false;
832                 bufferedByteCount += chunkSize;
833                 size -= chunkSize;
834             }
835             return true;
836         }
837 
838         /**
839          * Ensures that there is enough capacity in the buffer to store the specified number of
840          * bytes at the {@code firstBufferedByteOffset + bufferedByteCount} offset.
841          */
ensureRemainingBufferCapacityAtLeast(int size)842         private void ensureRemainingBufferCapacityAtLeast(int size) throws IOException {
843             int bufferCapacityRemaining =
844                     buffer.length - (firstBufferedByteOffset + bufferedByteCount);
845             if (bufferCapacityRemaining >= size) {
846                 return;
847             }
848             // Insufficient capacity at the end of the buffer.
849             if (firstBufferedByteOffset > 0) {
850                 // Some of the bytes at the start of the buffer have already been returned to the
851                 // client of this reader. Check if moving the remaining buffered bytes to the start
852                 // of the buffer will make enough space at the end of the buffer.
853                 bufferCapacityRemaining += firstBufferedByteOffset;
854                 if (bufferCapacityRemaining >= size) {
855                     System.arraycopy(buffer, firstBufferedByteOffset, buffer, 0, bufferedByteCount);
856                     firstBufferedByteOffset = 0;
857                     return;
858                 }
859             }
860 
861             throw new IOException("Insuffucient remaining capacity in the buffer. Requested: "
862                     + size + ", remaining: " + bufferCapacityRemaining);
863         }
864     }
865 
getUnsignedShortBigEndian(byte[] buf, int offset)866     private static int getUnsignedShortBigEndian(byte[] buf, int offset) {
867         return ((buf[offset] & 0xff) << 8) | (buf[offset + 1] & 0xff);
868     }
869 
putUnsignedShortBigEndian(byte[] buf, int offset, int value)870     private static void putUnsignedShortBigEndian(byte[] buf, int offset, int value) {
871         buf[offset] = (byte) ((value >>> 8) & 0xff);
872         buf[offset + 1] = (byte) (value & 0xff);
873     }
874 
    // IMPLEMENTATION NOTE: We can't implement just one closeQuietly(Closeable) because on some
    // older Android platforms Socket did not implement the Closeable interface. To make this patch
    // easy to apply to these older platforms, we declare all the variants of closeQuietly that are
    // needed without relying on the Closeable interface.
879 
closeQuietly(InputStream in)880     private static void closeQuietly(InputStream in) {
881         if (in != null) {
882             try {
883                 in.close();
884             } catch (IOException ignored) {}
885         }
886     }
887 
closeQuietly(ServerSocket socket)888     public static void closeQuietly(ServerSocket socket) {
889         if (socket != null) {
890             try {
891                 socket.close();
892             } catch (IOException ignored) {}
893         }
894     }
895 
closeQuietly(Socket socket)896     public static void closeQuietly(Socket socket) {
897         if (socket != null) {
898             try {
899                 socket.close();
900             } catch (IOException ignored) {}
901         }
902     }
903 
readResource(Context context, int resId)904     public static byte[] readResource(Context context, int resId) throws IOException {
905         ByteArrayOutputStream result = new ByteArrayOutputStream();
906         InputStream in = null;
907         byte[] buf = new byte[16 * 1024];
908         try {
909             in = context.getResources().openRawResource(resId);
910             int chunkSize;
911             while ((chunkSize = in.read(buf)) != -1) {
912                 result.write(buf, 0, chunkSize);
913             }
914             return result.toByteArray();
915         } finally {
916             closeQuietly(in);
917         }
918     }
919 
920     /**
921      * {@link X509TrustManager} which trusts all certificate chains.
922      */
923     public static class TrustAllX509TrustManager implements X509TrustManager {
924         @Override
checkClientTrusted(X509Certificate[] chain, String authType)925         public void checkClientTrusted(X509Certificate[] chain, String authType)
926                 throws CertificateException {
927         }
928 
929         @Override
checkServerTrusted(X509Certificate[] chain, String authType)930         public void checkServerTrusted(X509Certificate[] chain, String authType)
931                 throws CertificateException {
932         }
933 
934         @Override
getAcceptedIssuers()935         public X509Certificate[] getAcceptedIssuers() {
936             return new X509Certificate[0];
937         }
938     }
939 
940     /**
941      * {@link X509KeyManager} which uses the provided private key and cert chain for all sockets.
942      */
943     public static class HardcodedCertX509KeyManager implements X509KeyManager {
944 
945         private final PrivateKey mPrivateKey;
946         private final X509Certificate[] mCertChain;
947 
HardcodedCertX509KeyManager(PrivateKey privateKey, X509Certificate[] certChain)948         HardcodedCertX509KeyManager(PrivateKey privateKey, X509Certificate[] certChain) {
949             mPrivateKey = privateKey;
950             mCertChain = certChain;
951         }
952 
953         @Override
chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket)954         public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
955             return null;
956         }
957 
958         @Override
chooseServerAlias(String keyType, Principal[] issuers, Socket socket)959         public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
960             return "singleton";
961         }
962 
963         @Override
getCertificateChain(String alias)964         public X509Certificate[] getCertificateChain(String alias) {
965             return mCertChain;
966         }
967 
968         @Override
getClientAliases(String keyType, Principal[] issuers)969         public String[] getClientAliases(String keyType, Principal[] issuers) {
970             return null;
971         }
972 
973         @Override
getPrivateKey(String alias)974         public PrivateKey getPrivateKey(String alias) {
975             return mPrivateKey;
976         }
977 
978         @Override
getServerAliases(String keyType, Principal[] issuers)979         public String[] getServerAliases(String keyType, Principal[] issuers) {
980             return new String[] {"singleton"};
981         }
982     }
983 }
984