1 /* 2 * Copyright (C) 2014 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.security.cts; 18 19 import android.content.Context; 20 import android.test.InstrumentationTestCase; 21 import android.util.Log; 22 23 import com.android.cts.security.R; 24 25 import java.io.ByteArrayInputStream; 26 import java.io.ByteArrayOutputStream; 27 import java.io.EOFException; 28 import java.io.IOException; 29 import java.io.InputStream; 30 import java.io.OutputStream; 31 import java.net.ServerSocket; 32 import java.net.Socket; 33 import java.net.SocketAddress; 34 import java.security.KeyFactory; 35 import java.security.Principal; 36 import java.security.PrivateKey; 37 import java.security.cert.CertificateException; 38 import java.security.cert.CertificateFactory; 39 import java.security.cert.X509Certificate; 40 import java.security.spec.PKCS8EncodedKeySpec; 41 import java.util.concurrent.Callable; 42 import java.util.concurrent.ExecutionException; 43 import java.util.concurrent.ExecutorService; 44 import java.util.concurrent.Executors; 45 import java.util.concurrent.Future; 46 import java.util.concurrent.TimeUnit; 47 48 import javax.net.ServerSocketFactory; 49 import javax.net.SocketFactory; 50 import javax.net.ssl.KeyManager; 51 import javax.net.ssl.SSLContext; 52 import javax.net.ssl.SSLException; 53 import javax.net.ssl.SSLServerSocket; 54 import javax.net.ssl.SSLSocket; 55 import javax.net.ssl.TrustManager; 56 import javax.net.ssl.X509KeyManager; 57 import javax.net.ssl.X509TrustManager; 58 59 /** 60 * Tests for the OpenSSL Heartbleed vulnerability. 61 */ 62 public class OpenSSLHeartbleedTest extends InstrumentationTestCase { 63 64 // IMPLEMENTATION NOTE: This test spawns an SSLSocket client, SSLServerSocket server, and a 65 // Man-in-The-Middle (MiTM). The client connects to the MiTM which then connects to the server 66 // and starts forwarding all TLS records between the client and the server. In tests that check 67 // for the Heartbleed vulnerability, the MiTM also injects a HeartbeatRequest message into the 68 // traffic. 69 70 // IMPLEMENTATION NOTE: This test spawns several background threads that perform network I/O 71 // on localhost. To ensure that these background threads are cleaned up at the end of the test 72 // tearDown() kills the sockets they may be using. To aid this behavior, all Socket and 73 // ServerSocket instances are available as fields of this class. These fields should be accessed 74 // via setters and getters to avoid memory visibility issues due to concurrency. 75 76 private static final String TAG = OpenSSLHeartbleedTest.class.getSimpleName(); 77 78 private SSLServerSocket mServerListeningSocket; 79 private SSLSocket mServerSocket; 80 private SSLSocket mClientSocket; 81 private ServerSocket mMitmListeningSocket; 82 private Socket mMitmServerSocket; 83 private Socket mMitmClientSocket; 84 private ExecutorService mExecutorService; 85 86 private boolean mHeartbeatRequestWasInjected; 87 private boolean mHeartbeatResponseWasDetetected; 88 private int mFirstDetectedFatalAlertDescription = -1; 89 90 @Override tearDown()91 protected void tearDown() throws Exception { 92 Log.i(TAG, "Tearing down"); 93 if (mExecutorService != null) { 94 mExecutorService.shutdownNow(); 95 } 96 closeQuietly(getServerListeningSocket()); 97 closeQuietly(getServerSocket()); 98 closeQuietly(getClientSocket()); 99 closeQuietly(getMitmListeningSocket()); 100 closeQuietly(getMitmServerSocket()); 101 closeQuietly(getMitmClientSocket()); 102 super.tearDown(); 103 Log.i(TAG, "Tear down completed"); 104 } 105 106 /** 107 * Tests that TLS handshake succeeds when the MiTM simply forwards all data without tampering 108 * with it. This is to catch issues unrelated to TLS heartbeats. 109 */ testWithoutHeartbeats()110 public void testWithoutHeartbeats() throws Exception { 111 handshake(false, false); 112 } 113 114 /** 115 * Tests whether client sockets are vulnerable to Heartbleed. 116 */ testClientHeartbleed()117 public void testClientHeartbleed() throws Exception { 118 checkHeartbleed(true); 119 } 120 121 /** 122 * Tests whether server sockets are vulnerable to Heartbleed. 123 */ testServerHeartbleed()124 public void testServerHeartbleed() throws Exception { 125 checkHeartbleed(false); 126 } 127 128 /** 129 * Tests for Heartbleed. 130 * 131 * @param client {@code true} to test the client, {@code false} to test the server. 132 */ checkHeartbleed(boolean client)133 private void checkHeartbleed(boolean client) throws Exception { 134 // IMPLEMENTATION NOTE: The MiTM is forwarding all TLS records between the client and the 135 // server unmodified. Additionally, the MiTM transmits a malformed HeartbeatRequest to 136 // server (if "client" argument is false) right after client's ClientKeyExchange or to 137 // client (if "client" argument is true) right after server's ServerHello. The peer is 138 // expected to either ignore the HeartbeatRequest (if heartbeats are supported) or to abort 139 // the handshake with unexpected_message alert (if heartbeats are not supported). 140 try { 141 handshake(true, client); 142 } catch (ExecutionException e) { 143 assertFalse( 144 "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server") 145 + " mode", 146 wasHeartbeatResponseDetected()); 147 if (e.getCause() instanceof SSLException) { 148 // TLS handshake or data exchange failed. Check whether the error was caused by 149 // fatal alert unexpected_message 150 int alertDescription = getFirstDetectedFatalAlertDescription(); 151 if (alertDescription == -1) { 152 fail("Handshake failed without a fatal alert"); 153 } 154 assertEquals( 155 "First fatal alert description received from server", 156 AlertMessage.DESCRIPTION_UNEXPECTED_MESSAGE, 157 alertDescription); 158 return; 159 } else { 160 throw e; 161 } 162 } 163 164 // TLS handshake succeeded 165 assertFalse( 166 "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server") 167 + " mode", 168 wasHeartbeatResponseDetected()); 169 assertTrue("HeartbeatRequest not injected", wasHeartbeatRequestInjected()); 170 } 171 172 /** 173 * Starts the client, server, and the MiTM. Makes the client and server perform a TLS handshake 174 * and exchange application-level data. The MiTM injects a HeartbeatRequest message if requested 175 * by {@code heartbeatRequestInjected}. The direction of the injected message is specified by 176 * {@code injectedIntoClient}. 177 */ handshake( final boolean heartbeatRequestInjected, final boolean injectedIntoClient)178 private void handshake( 179 final boolean heartbeatRequestInjected, 180 final boolean injectedIntoClient) throws Exception { 181 mExecutorService = Executors.newFixedThreadPool(4); 182 setServerListeningSocket(serverBind()); 183 final SocketAddress serverAddress = getServerListeningSocket().getLocalSocketAddress(); 184 Log.i(TAG, "Server bound to " + serverAddress); 185 186 setMitmListeningSocket(mitmBind()); 187 final SocketAddress mitmAddress = getMitmListeningSocket().getLocalSocketAddress(); 188 Log.i(TAG, "MiTM bound to " + mitmAddress); 189 190 // Start the MiTM daemon in the background 191 mExecutorService.submit(new Callable<Void>() { 192 @Override 193 public Void call() throws Exception { 194 mitmAcceptAndForward( 195 serverAddress, 196 heartbeatRequestInjected, 197 injectedIntoClient); 198 return null; 199 } 200 }); 201 // Start the server in the background 202 Future<Void> serverFuture = mExecutorService.submit(new Callable<Void>() { 203 @Override 204 public Void call() throws Exception { 205 serverAcceptAndHandshake(); 206 return null; 207 } 208 }); 209 // Start the client in the background 210 Future<Void> clientFuture = mExecutorService.submit(new Callable<Void>() { 211 @Override 212 public Void call() throws Exception { 213 clientConnectAndHandshake(mitmAddress); 214 return null; 215 } 216 }); 217 218 // Wait for both client and server to terminate, to ensure that we observe all the traffic 219 // exchanged between them. Throw an exception if one of them failed. 220 Log.i(TAG, "Waiting for client"); 221 // Wait for the client, but don't yet throw an exception if it failed. 222 Exception clientException = null; 223 try { 224 clientFuture.get(10, TimeUnit.SECONDS); 225 } catch (Exception e) { 226 clientException = e; 227 } 228 Log.i(TAG, "Waiting for server"); 229 // Wait for the server and throw an exception if it failed. 230 serverFuture.get(5, TimeUnit.SECONDS); 231 // Throw an exception if the client failed. 232 if (clientException != null) { 233 throw clientException; 234 } 235 Log.i(TAG, "Handshake completed and application data exchanged"); 236 } 237 clientConnectAndHandshake(SocketAddress serverAddress)238 private void clientConnectAndHandshake(SocketAddress serverAddress) throws Exception { 239 SSLContext sslContext = SSLContext.getInstance("TLS"); 240 sslContext.init( 241 null, 242 new TrustManager[] {new TrustAllX509TrustManager()}, 243 null); 244 SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket(); 245 setClientSocket(socket); 246 try { 247 Log.i(TAG, "Client connecting to " + serverAddress); 248 socket.connect(serverAddress); 249 Log.i(TAG, "Client connected to server from " + socket.getLocalSocketAddress()); 250 // Ensure a TLS handshake is performed and an exception is thrown if it fails. 251 socket.getOutputStream().write("client".getBytes()); 252 socket.getOutputStream().flush(); 253 Log.i(TAG, "Client sent request. Reading response"); 254 int b = socket.getInputStream().read(); 255 Log.i(TAG, "Client read response: " + b); 256 } catch (Exception e) { 257 Log.w(TAG, "Client failed", e); 258 throw e; 259 } finally { 260 socket.close(); 261 } 262 } 263 serverBind()264 public SSLServerSocket serverBind() throws Exception { 265 // Load the server's private key and cert chain 266 KeyFactory keyFactory = KeyFactory.getInstance("RSA"); 267 PrivateKey privateKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec( 268 readResource( 269 getInstrumentation().getContext(), R.raw.openssl_heartbleed_test_key))); 270 CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); 271 X509Certificate[] certChain = new X509Certificate[] { 272 (X509Certificate) certFactory.generateCertificate( 273 new ByteArrayInputStream(readResource( 274 getInstrumentation().getContext(), 275 R.raw.openssl_heartbleed_test_cert))) 276 }; 277 278 // Initialize TLS context to use the private key and cert chain for server sockets 279 SSLContext sslContext = SSLContext.getInstance("TLS"); 280 sslContext.init( 281 new KeyManager[] {new HardcodedCertX509KeyManager(privateKey, certChain)}, 282 null, 283 null); 284 285 Log.i(TAG, "Server binding to local port"); 286 return (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(0); 287 } 288 serverAcceptAndHandshake()289 private void serverAcceptAndHandshake() throws Exception { 290 SSLSocket socket = null; 291 SSLServerSocket serverSocket = getServerListeningSocket(); 292 try { 293 Log.i(TAG, "Server listening for incoming connection"); 294 socket = (SSLSocket) serverSocket.accept(); 295 setServerSocket(socket); 296 Log.i(TAG, "Server accepted connection from " + socket.getRemoteSocketAddress()); 297 // Ensure a TLS handshake is performed and an exception is thrown if it fails. 298 socket.getOutputStream().write("server".getBytes()); 299 socket.getOutputStream().flush(); 300 Log.i(TAG, "Server sent reply. Reading response"); 301 int b = socket.getInputStream().read(); 302 Log.i(TAG, "Server read response: " + b); 303 } catch (Exception e) { 304 Log.w(TAG, "Server failed", e); 305 throw e; 306 } finally { 307 if (socket != null) { 308 socket.close(); 309 } 310 } 311 } 312 mitmBind()313 private ServerSocket mitmBind() throws Exception { 314 Log.i(TAG, "MiTM binding to local port"); 315 return ServerSocketFactory.getDefault().createServerSocket(0); 316 } 317 318 /** 319 * Accepts the connection on the MiTM listening socket, forwards the TLS records between the 320 * client and the server, and, if requested, injects a {@code HeartbeatRequest}. 321 * 322 * @param injectHeartbeat whether to inject a {@code HeartbeatRequest} message. 323 * @param injectIntoClient when {@code injectHeartbeat} is {@code true}, whether to inject the 324 * {@code HeartbeatRequest} message into client or into server. 325 */ mitmAcceptAndForward( SocketAddress serverAddress, final boolean injectHeartbeat, final boolean injectIntoClient)326 private void mitmAcceptAndForward( 327 SocketAddress serverAddress, 328 final boolean injectHeartbeat, 329 final boolean injectIntoClient) throws Exception { 330 Socket clientSocket = null; 331 Socket serverSocket = null; 332 ServerSocket listeningSocket = getMitmListeningSocket(); 333 try { 334 Log.i(TAG, "MiTM waiting for incoming connection"); 335 clientSocket = listeningSocket.accept(); 336 setMitmClientSocket(clientSocket); 337 Log.i(TAG, "MiTM accepted connection from " + clientSocket.getRemoteSocketAddress()); 338 serverSocket = SocketFactory.getDefault().createSocket(); 339 setMitmServerSocket(serverSocket); 340 Log.i(TAG, "MiTM connecting to server " + serverAddress); 341 serverSocket.connect(serverAddress, 10000); 342 Log.i(TAG, "MiTM connected to server from " + serverSocket.getLocalSocketAddress()); 343 final InputStream serverInputStream = serverSocket.getInputStream(); 344 final OutputStream clientOutputStream = clientSocket.getOutputStream(); 345 Future<Void> serverToClientTask = mExecutorService.submit(new Callable<Void>() { 346 @Override 347 public Void call() throws Exception { 348 // Inject HeatbeatRequest after ServerHello, if requested 349 forwardTlsRecords( 350 "MiTM S->C", 351 serverInputStream, 352 clientOutputStream, 353 (injectHeartbeat && injectIntoClient) 354 ? HandshakeMessage.TYPE_SERVER_HELLO : -1); 355 return null; 356 } 357 }); 358 // Inject HeatbeatRequest after ClientKeyExchange, if requested 359 forwardTlsRecords( 360 "MiTM C->S", 361 clientSocket.getInputStream(), 362 serverSocket.getOutputStream(), 363 (injectHeartbeat && !injectIntoClient) 364 ? HandshakeMessage.TYPE_CLIENT_KEY_EXCHANGE : -1); 365 serverToClientTask.get(10, TimeUnit.SECONDS); 366 } catch (Exception e) { 367 Log.w(TAG, "MiTM failed", e); 368 throw e; 369 } finally { 370 closeQuietly(clientSocket); 371 closeQuietly(serverSocket); 372 } 373 } 374 375 /** 376 * Forwards TLS records from the provided {@code InputStream} to the provided 377 * {@code OutputStream}. If requested, injects a {@code HeartbeatMessage}. 378 */ forwardTlsRecords( String logPrefix, InputStream in, OutputStream out, int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest)379 private void forwardTlsRecords( 380 String logPrefix, 381 InputStream in, 382 OutputStream out, 383 int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws Exception { 384 Log.i(TAG, logPrefix + ": record forwarding started"); 385 boolean interestingRecordsLogged = 386 handshakeMessageTypeAfterWhichToInjectHeartbeatRequest == -1; 387 try { 388 TlsRecordReader reader = new TlsRecordReader(in); 389 byte[] recordBytes; 390 // Fragments contained in records may be encrypted after a certain point in the 391 // handshake. Once they are encrypted, this MiTM cannot inspect their plaintext which. 392 boolean fragmentEncryptionMayBeEnabled = false; 393 while ((recordBytes = reader.readRecord()) != null) { 394 TlsRecord record = TlsRecord.parse(recordBytes); 395 forwardTlsRecord(logPrefix, 396 recordBytes, 397 record, 398 fragmentEncryptionMayBeEnabled, 399 out, 400 interestingRecordsLogged, 401 handshakeMessageTypeAfterWhichToInjectHeartbeatRequest); 402 if (record.protocol == TlsProtocols.CHANGE_CIPHER_SPEC) { 403 fragmentEncryptionMayBeEnabled = true; 404 } 405 } 406 } catch (Exception e) { 407 Log.w(TAG, logPrefix + ": failed", e); 408 throw e; 409 } finally { 410 Log.d(TAG, logPrefix + ": record forwarding finished"); 411 } 412 } 413 forwardTlsRecord( String logPrefix, byte[] recordBytes, TlsRecord record, boolean fragmentEncryptionMayBeEnabled, OutputStream out, boolean interestingRecordsLogged, int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest)414 private void forwardTlsRecord( 415 String logPrefix, 416 byte[] recordBytes, 417 TlsRecord record, 418 boolean fragmentEncryptionMayBeEnabled, 419 OutputStream out, 420 boolean interestingRecordsLogged, 421 int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws IOException { 422 // Save information about the records if its of interest to this test 423 if (interestingRecordsLogged) { 424 switch (record.protocol) { 425 case TlsProtocols.ALERT: 426 if (!fragmentEncryptionMayBeEnabled) { 427 AlertMessage alert = AlertMessage.tryParse(record); 428 if ((alert != null) && (alert.level == AlertMessage.LEVEL_FATAL)) { 429 setFatalAlertDetected(alert.description); 430 } 431 } 432 break; 433 case TlsProtocols.HEARTBEAT: 434 // When TLS records are encrypted, we cannot determine whether a 435 // heartbeat is a HeartbeatResponse. In our setup, the client and the 436 // server are not expected to sent HeartbeatRequests. Thus, we err on 437 // the side of caution and assume that any heartbeat message sent by 438 // client or server is a HeartbeatResponse. 439 Log.e(TAG, logPrefix 440 + ": heartbeat response detected -- vulnerable to Heartbleed"); 441 setHeartbeatResponseWasDetected(); 442 break; 443 } 444 } 445 446 Log.i(TAG, logPrefix + ": Forwarding TLS record. " 447 + getRecordInfo(record, fragmentEncryptionMayBeEnabled)); 448 out.write(recordBytes); 449 out.flush(); 450 451 // Inject HeartbeatRequest, if necessary, after the specified handshake message type 452 if (handshakeMessageTypeAfterWhichToInjectHeartbeatRequest != -1) { 453 if ((!fragmentEncryptionMayBeEnabled) && (isHandshakeMessageType( 454 record, handshakeMessageTypeAfterWhichToInjectHeartbeatRequest))) { 455 // The Heartbeat Request message below is malformed because its declared 456 // length of payload one byte larger than the actual payload. The peer is 457 // supposed to reject such messages. 458 byte[] payload = "arbitrary".getBytes("US-ASCII"); 459 byte[] heartbeatRequestRecordBytes = createHeartbeatRequestRecord( 460 record.versionMajor, 461 record.versionMinor, 462 payload.length + 1, 463 payload); 464 Log.i(TAG, logPrefix + ": Injecting malformed HeartbeatRequest: " 465 + getRecordInfo( 466 TlsRecord.parse(heartbeatRequestRecordBytes), false)); 467 setHeartbeatRequestWasInjected(); 468 out.write(heartbeatRequestRecordBytes); 469 out.flush(); 470 } 471 } 472 } 473 getRecordInfo(TlsRecord record, boolean mayBeEncrypted)474 private static String getRecordInfo(TlsRecord record, boolean mayBeEncrypted) { 475 StringBuilder result = new StringBuilder(); 476 result.append(getProtocolName(record.protocol)) 477 .append(", ") 478 .append(getFragmentInfo(record, mayBeEncrypted)); 479 return result.toString(); 480 } 481 getProtocolName(int protocol)482 private static String getProtocolName(int protocol) { 483 switch (protocol) { 484 case TlsProtocols.ALERT: 485 return "alert"; 486 case TlsProtocols.APPLICATION_DATA: 487 return "application data"; 488 case TlsProtocols.CHANGE_CIPHER_SPEC: 489 return "change cipher spec"; 490 case TlsProtocols.HANDSHAKE: 491 return "handshake"; 492 case TlsProtocols.HEARTBEAT: 493 return "heatbeat"; 494 default: 495 return String.valueOf(protocol); 496 } 497 } 498 getFragmentInfo(TlsRecord record, boolean mayBeEncrypted)499 private static String getFragmentInfo(TlsRecord record, boolean mayBeEncrypted) { 500 StringBuilder result = new StringBuilder(); 501 if (mayBeEncrypted) { 502 result.append("encrypted?"); 503 } else { 504 switch (record.protocol) { 505 case TlsProtocols.ALERT: 506 result.append("level: " + ((record.fragment.length > 0) 507 ? String.valueOf(record.fragment[0] & 0xff) : "n/a") 508 + ", description: " 509 + ((record.fragment.length > 1) 510 ? String.valueOf(record.fragment[1] & 0xff) : "n/a")); 511 break; 512 case TlsProtocols.APPLICATION_DATA: 513 break; 514 case TlsProtocols.CHANGE_CIPHER_SPEC: 515 result.append("payload: " + ((record.fragment.length > 0) 516 ? String.valueOf(record.fragment[0] & 0xff) : "n/a")); 517 break; 518 case TlsProtocols.HANDSHAKE: 519 result.append("type: " + ((record.fragment.length > 0) 520 ? String.valueOf(record.fragment[0] & 0xff) : "n/a")); 521 break; 522 case TlsProtocols.HEARTBEAT: 523 result.append("type: " + ((record.fragment.length > 0) 524 ? String.valueOf(record.fragment[0] & 0xff) : "n/a") 525 + ", payload length: " 526 + ((record.fragment.length >= 3) 527 ? String.valueOf( 528 getUnsignedShortBigEndian(record.fragment, 1)) 529 : "n/a")); 530 break; 531 } 532 } 533 result.append(", ").append("fragment length: " + record.fragment.length); 534 return result.toString(); 535 } 536 setServerListeningSocket(SSLServerSocket socket)537 private synchronized void setServerListeningSocket(SSLServerSocket socket) { 538 mServerListeningSocket = socket; 539 } 540 getServerListeningSocket()541 private synchronized SSLServerSocket getServerListeningSocket() { 542 return mServerListeningSocket; 543 } 544 setServerSocket(SSLSocket socket)545 private synchronized void setServerSocket(SSLSocket socket) { 546 mServerSocket = socket; 547 } 548 getServerSocket()549 private synchronized SSLSocket getServerSocket() { 550 return mServerSocket; 551 } 552 setClientSocket(SSLSocket socket)553 private synchronized void setClientSocket(SSLSocket socket) { 554 mClientSocket = socket; 555 } 556 getClientSocket()557 private synchronized SSLSocket getClientSocket() { 558 return mClientSocket; 559 } 560 setMitmListeningSocket(ServerSocket socket)561 private synchronized void setMitmListeningSocket(ServerSocket socket) { 562 mMitmListeningSocket = socket; 563 } 564 getMitmListeningSocket()565 private synchronized ServerSocket getMitmListeningSocket() { 566 return mMitmListeningSocket; 567 } 568 setMitmServerSocket(Socket socket)569 private synchronized void setMitmServerSocket(Socket socket) { 570 mMitmServerSocket = socket; 571 } 572 getMitmServerSocket()573 private synchronized Socket getMitmServerSocket() { 574 return mMitmServerSocket; 575 } 576 setMitmClientSocket(Socket socket)577 private synchronized void setMitmClientSocket(Socket socket) { 578 mMitmClientSocket = socket; 579 } 580 getMitmClientSocket()581 private synchronized Socket getMitmClientSocket() { 582 return mMitmClientSocket; 583 } 584 setHeartbeatRequestWasInjected()585 private synchronized void setHeartbeatRequestWasInjected() { 586 mHeartbeatRequestWasInjected = true; 587 } 588 wasHeartbeatRequestInjected()589 private synchronized boolean wasHeartbeatRequestInjected() { 590 return mHeartbeatRequestWasInjected; 591 } 592 setHeartbeatResponseWasDetected()593 private synchronized void setHeartbeatResponseWasDetected() { 594 mHeartbeatResponseWasDetetected = true; 595 } 596 wasHeartbeatResponseDetected()597 private synchronized boolean wasHeartbeatResponseDetected() { 598 return mHeartbeatResponseWasDetetected; 599 } 600 setFatalAlertDetected(int description)601 private synchronized void setFatalAlertDetected(int description) { 602 if (mFirstDetectedFatalAlertDescription == -1) { 603 mFirstDetectedFatalAlertDescription = description; 604 } 605 } 606 getFirstDetectedFatalAlertDescription()607 private synchronized int getFirstDetectedFatalAlertDescription() { 608 return mFirstDetectedFatalAlertDescription; 609 } 610 611 public static abstract class TlsProtocols { 612 public static final int CHANGE_CIPHER_SPEC = 20; 613 public static final int ALERT = 21; 614 public static final int HANDSHAKE = 22; 615 public static final int APPLICATION_DATA = 23; 616 public static final int HEARTBEAT = 24; TlsProtocols()617 private TlsProtocols() {} 618 } 619 620 public static class TlsRecord { 621 public int protocol; 622 public int versionMajor; 623 public int versionMinor; 624 public byte[] fragment; 625 parse(byte[] record)626 public static TlsRecord parse(byte[] record) throws IOException { 627 TlsRecord result = new TlsRecord(); 628 if (record.length < TlsRecordReader.RECORD_HEADER_LENGTH) { 629 throw new IOException("Record too short: " + record.length); 630 } 631 result.protocol = record[0] & 0xff; 632 result.versionMajor = record[1] & 0xff; 633 result.versionMinor = record[2] & 0xff; 634 int fragmentLength = getUnsignedShortBigEndian(record, 3); 635 int actualFragmentLength = record.length - TlsRecordReader.RECORD_HEADER_LENGTH; 636 if (fragmentLength != actualFragmentLength) { 637 throw new IOException("Fragment length mismatch. Expected: " + fragmentLength 638 + ", actual: " + actualFragmentLength); 639 } 640 result.fragment = new byte[fragmentLength]; 641 System.arraycopy( 642 record, TlsRecordReader.RECORD_HEADER_LENGTH, 643 result.fragment, 0, 644 fragmentLength); 645 return result; 646 } 647 unparse(TlsRecord record)648 public static byte[] unparse(TlsRecord record) { 649 byte[] result = new byte[TlsRecordReader.RECORD_HEADER_LENGTH + record.fragment.length]; 650 result[0] = (byte) record.protocol; 651 result[1] = (byte) record.versionMajor; 652 result[2] = (byte) record.versionMinor; 653 putUnsignedShortBigEndian(result, 3, record.fragment.length); 654 System.arraycopy( 655 record.fragment, 0, 656 result, TlsRecordReader.RECORD_HEADER_LENGTH, 657 record.fragment.length); 658 return result; 659 } 660 } 661 isHandshakeMessageType(TlsRecord record, int type)662 public static final boolean isHandshakeMessageType(TlsRecord record, int type) { 663 HandshakeMessage handshake = HandshakeMessage.tryParse(record); 664 if (handshake == null) { 665 return false; 666 } 667 return handshake.type == type; 668 } 669 670 public static class HandshakeMessage { 671 public static final int TYPE_SERVER_HELLO = 2; 672 public static final int TYPE_CERTIFICATE = 11; 673 public static final int TYPE_CLIENT_KEY_EXCHANGE = 16; 674 675 public int type; 676 677 /** 678 * Parses the provided TLS record as a handshake message. 679 * 680 * @return alert message or {@code null} if the record does not contain a handshake message. 681 */ tryParse(TlsRecord record)682 public static HandshakeMessage tryParse(TlsRecord record) { 683 if (record.protocol != TlsProtocols.HANDSHAKE) { 684 return null; 685 } 686 if (record.fragment.length < 1) { 687 return null; 688 } 689 HandshakeMessage result = new HandshakeMessage(); 690 result.type = record.fragment[0] & 0xff; 691 return result; 692 } 693 } 694 695 public static class AlertMessage { 696 public static final int LEVEL_FATAL = 2; 697 public static final int DESCRIPTION_UNEXPECTED_MESSAGE = 10; 698 699 public int level; 700 public int description; 701 702 /** 703 * Parses the provided TLS record as an alert message. 704 * 705 * @return alert message or {@code null} if the record does not contain an alert message. 706 */ tryParse(TlsRecord record)707 public static AlertMessage tryParse(TlsRecord record) { 708 if (record.protocol != TlsProtocols.ALERT) { 709 return null; 710 } 711 if (record.fragment.length < 2) { 712 return null; 713 } 714 AlertMessage result = new AlertMessage(); 715 result.level = record.fragment[0] & 0xff; 716 result.description = record.fragment[1] & 0xff; 717 return result; 718 } 719 } 720 721 private static abstract class HeartbeatProtocol { HeartbeatProtocol()722 private HeartbeatProtocol() {} 723 724 private static final int MESSAGE_TYPE_REQUEST = 1; 725 @SuppressWarnings("unused") 726 private static final int MESSAGE_TYPE_RESPONSE = 2; 727 728 private static final int MESSAGE_HEADER_LENGTH = 3; 729 private static final int MESSAGE_PADDING_LENGTH = 16; 730 } 731 createHeartbeatRequestRecord( int versionMajor, int versionMinor, int declaredPayloadLength, byte[] payload)732 private static byte[] createHeartbeatRequestRecord( 733 int versionMajor, int versionMinor, 734 int declaredPayloadLength, byte[] payload) { 735 736 byte[] fragment = new byte[HeartbeatProtocol.MESSAGE_HEADER_LENGTH 737 + payload.length + HeartbeatProtocol.MESSAGE_PADDING_LENGTH]; 738 fragment[0] = HeartbeatProtocol.MESSAGE_TYPE_REQUEST; 739 putUnsignedShortBigEndian(fragment, 1, declaredPayloadLength); // payload_length 740 TlsRecord record = new TlsRecord(); 741 record.protocol = TlsProtocols.HEARTBEAT; 742 record.versionMajor = versionMajor; 743 record.versionMinor = versionMinor; 744 record.fragment = fragment; 745 return TlsRecord.unparse(record); 746 } 747 748 /** 749 * Reader of TLS records. 750 */ 751 public static class TlsRecordReader { 752 private static final int MAX_RECORD_LENGTH = 16384; 753 public static final int RECORD_HEADER_LENGTH = 5; 754 755 private final InputStream in; 756 private final byte[] buffer; 757 private int firstBufferedByteOffset; 758 private int bufferedByteCount; 759 TlsRecordReader(InputStream in)760 public TlsRecordReader(InputStream in) { 761 this.in = in; 762 buffer = new byte[MAX_RECORD_LENGTH]; 763 } 764 765 /** 766 * Reads the next TLS record. 767 * 768 * @return TLS record or {@code null} if EOF was encountered before any bytes of a record 769 * could be read. 770 */ readRecord()771 public byte[] readRecord() throws IOException { 772 // Ensure that a TLS record header (or more) is in the buffer. 773 if (bufferedByteCount < RECORD_HEADER_LENGTH) { 774 boolean eofPermittedInstead = (bufferedByteCount == 0); 775 boolean eofEncounteredInstead = 776 !readAtLeast(RECORD_HEADER_LENGTH, eofPermittedInstead); 777 if (eofEncounteredInstead) { 778 // End of stream reached exactly before a TLS record start. 779 return null; 780 } 781 } 782 783 // TLS record header (or more) is in the buffer. 784 // Ensure that the rest of the record is in the buffer. 785 int fragmentLength = getUnsignedShortBigEndian(buffer, firstBufferedByteOffset + 3); 786 int recordLength = RECORD_HEADER_LENGTH + fragmentLength; 787 if (recordLength > MAX_RECORD_LENGTH) { 788 throw new IOException("TLS record too long: " + recordLength); 789 } 790 if (bufferedByteCount < recordLength) { 791 readAtLeast(recordLength - bufferedByteCount, false); 792 } 793 794 // TLS record (or more) is in the buffer. 795 byte[] record = new byte[recordLength]; 796 System.arraycopy(buffer, firstBufferedByteOffset, record, 0, recordLength); 797 firstBufferedByteOffset += recordLength; 798 bufferedByteCount -= recordLength; 799 return record; 800 } 801 802 /** 803 * Reads at least the specified number of bytes from the underlying {@code InputStream} into 804 * the {@code buffer}. 805 * 806 * <p>Bytes buffered but not yet returned to the client in the {@code buffer} are relocated 807 * to the start of the buffer to make space if necessary. 808 * 809 * @param eofPermittedInstead {@code true} if it's permitted for an EOF to be encountered 810 * without any bytes having been read. 811 * 812 * @return {@code true} if the requested number of bytes (or more) has been read, 813 * {@code false} if {@code eofPermittedInstead} was {@code true} and EOF was 814 * encountered when no bytes have yet been read. 815 */ readAtLeast(int size, boolean eofPermittedInstead)816 private boolean readAtLeast(int size, boolean eofPermittedInstead) throws IOException { 817 ensureRemainingBufferCapacityAtLeast(size); 818 boolean firstAttempt = true; 819 while (size > 0) { 820 int chunkSize = in.read( 821 buffer, 822 firstBufferedByteOffset + bufferedByteCount, 823 buffer.length - (firstBufferedByteOffset + bufferedByteCount)); 824 if (chunkSize == -1) { 825 if ((firstAttempt) && (eofPermittedInstead)) { 826 return false; 827 } else { 828 throw new EOFException("Premature EOF"); 829 } 830 } 831 firstAttempt = false; 832 bufferedByteCount += chunkSize; 833 size -= chunkSize; 834 } 835 return true; 836 } 837 838 /** 839 * Ensures that there is enough capacity in the buffer to store the specified number of 840 * bytes at the {@code firstBufferedByteOffset + bufferedByteCount} offset. 841 */ ensureRemainingBufferCapacityAtLeast(int size)842 private void ensureRemainingBufferCapacityAtLeast(int size) throws IOException { 843 int bufferCapacityRemaining = 844 buffer.length - (firstBufferedByteOffset + bufferedByteCount); 845 if (bufferCapacityRemaining >= size) { 846 return; 847 } 848 // Insufficient capacity at the end of the buffer. 849 if (firstBufferedByteOffset > 0) { 850 // Some of the bytes at the start of the buffer have already been returned to the 851 // client of this reader. Check if moving the remaining buffered bytes to the start 852 // of the buffer will make enough space at the end of the buffer. 853 bufferCapacityRemaining += firstBufferedByteOffset; 854 if (bufferCapacityRemaining >= size) { 855 System.arraycopy(buffer, firstBufferedByteOffset, buffer, 0, bufferedByteCount); 856 firstBufferedByteOffset = 0; 857 return; 858 } 859 } 860 861 throw new IOException("Insuffucient remaining capacity in the buffer. Requested: " 862 + size + ", remaining: " + bufferCapacityRemaining); 863 } 864 } 865 getUnsignedShortBigEndian(byte[] buf, int offset)866 private static int getUnsignedShortBigEndian(byte[] buf, int offset) { 867 return ((buf[offset] & 0xff) << 8) | (buf[offset + 1] & 0xff); 868 } 869 putUnsignedShortBigEndian(byte[] buf, int offset, int value)870 private static void putUnsignedShortBigEndian(byte[] buf, int offset, int value) { 871 buf[offset] = (byte) ((value >>> 8) & 0xff); 872 buf[offset + 1] = (byte) (value & 0xff); 873 } 874 875 // IMPLEMENTATION NOTE: We can't implement just one closeQueietly(Closeable) because on some 876 // older Android platforms Socket did not implement these interfaces. To make this patch easy to 877 // apply to these older platforms, we declare all the variants of closeQuietly that are needed 878 // without relying on the Closeable interface. 879 closeQuietly(InputStream in)880 private static void closeQuietly(InputStream in) { 881 if (in != null) { 882 try { 883 in.close(); 884 } catch (IOException ignored) {} 885 } 886 } 887 closeQuietly(ServerSocket socket)888 public static void closeQuietly(ServerSocket socket) { 889 if (socket != null) { 890 try { 891 socket.close(); 892 } catch (IOException ignored) {} 893 } 894 } 895 closeQuietly(Socket socket)896 public static void closeQuietly(Socket socket) { 897 if (socket != null) { 898 try { 899 socket.close(); 900 } catch (IOException ignored) {} 901 } 902 } 903 readResource(Context context, int resId)904 public static byte[] readResource(Context context, int resId) throws IOException { 905 ByteArrayOutputStream result = new ByteArrayOutputStream(); 906 InputStream in = null; 907 byte[] buf = new byte[16 * 1024]; 908 try { 909 in = context.getResources().openRawResource(resId); 910 int chunkSize; 911 while ((chunkSize = in.read(buf)) != -1) { 912 result.write(buf, 0, chunkSize); 913 } 914 return result.toByteArray(); 915 } finally { 916 closeQuietly(in); 917 } 918 } 919 920 /** 921 * {@link X509TrustManager} which trusts all certificate chains. 922 */ 923 public static class TrustAllX509TrustManager implements X509TrustManager { 924 @Override checkClientTrusted(X509Certificate[] chain, String authType)925 public void checkClientTrusted(X509Certificate[] chain, String authType) 926 throws CertificateException { 927 } 928 929 @Override checkServerTrusted(X509Certificate[] chain, String authType)930 public void checkServerTrusted(X509Certificate[] chain, String authType) 931 throws CertificateException { 932 } 933 934 @Override getAcceptedIssuers()935 public X509Certificate[] getAcceptedIssuers() { 936 return new X509Certificate[0]; 937 } 938 } 939 940 /** 941 * {@link X509KeyManager} which uses the provided private key and cert chain for all sockets. 942 */ 943 public static class HardcodedCertX509KeyManager implements X509KeyManager { 944 945 private final PrivateKey mPrivateKey; 946 private final X509Certificate[] mCertChain; 947 HardcodedCertX509KeyManager(PrivateKey privateKey, X509Certificate[] certChain)948 HardcodedCertX509KeyManager(PrivateKey privateKey, X509Certificate[] certChain) { 949 mPrivateKey = privateKey; 950 mCertChain = certChain; 951 } 952 953 @Override chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket)954 public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { 955 return null; 956 } 957 958 @Override chooseServerAlias(String keyType, Principal[] issuers, Socket socket)959 public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { 960 return "singleton"; 961 } 962 963 @Override getCertificateChain(String alias)964 public X509Certificate[] getCertificateChain(String alias) { 965 return mCertChain; 966 } 967 968 @Override getClientAliases(String keyType, Principal[] issuers)969 public String[] getClientAliases(String keyType, Principal[] issuers) { 970 return null; 971 } 972 973 @Override getPrivateKey(String alias)974 public PrivateKey getPrivateKey(String alias) { 975 return mPrivateKey; 976 } 977 978 @Override getServerAliases(String keyType, Principal[] issuers)979 public String[] getServerAliases(String keyType, Principal[] issuers) { 980 return new String[] {"singleton"}; 981 } 982 } 983 } 984