1 /*
2  * Copyright (C) 2011 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.mockwebserver;
18 
19 import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
20 import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
21 import java.io.BufferedInputStream;
22 import java.io.BufferedOutputStream;
23 import java.io.ByteArrayOutputStream;
24 import java.io.IOException;
25 import java.io.InputStream;
26 import java.io.OutputStream;
27 import java.net.InetAddress;
28 import java.net.InetSocketAddress;
29 import java.net.MalformedURLException;
30 import java.net.Proxy;
31 import java.net.ServerSocket;
32 import java.net.Socket;
33 import java.net.SocketException;
34 import java.net.URL;
35 import java.net.UnknownHostException;
36 import java.nio.charset.StandardCharsets;
37 import java.security.SecureRandom;
38 import java.security.cert.CertificateException;
39 import java.security.cert.X509Certificate;
40 import java.util.ArrayList;
41 import java.util.Iterator;
42 import java.util.List;
43 import java.util.Locale;
44 import java.util.Map;
45 import java.util.concurrent.BlockingQueue;
46 import java.util.concurrent.ConcurrentHashMap;
47 import java.util.concurrent.ExecutorService;
48 import java.util.concurrent.Executors;
49 import java.util.concurrent.LinkedBlockingQueue;
50 import java.util.concurrent.atomic.AtomicInteger;
51 import java.util.logging.Level;
52 import java.util.logging.Logger;
53 import javax.net.ssl.SSLContext;
54 import javax.net.ssl.SSLSocket;
55 import javax.net.ssl.SSLSocketFactory;
56 import javax.net.ssl.TrustManager;
57 import javax.net.ssl.X509TrustManager;
58 
59 /**
60  * A scriptable web server. Callers supply canned responses and the server
61  * replays them upon request in sequence.
62  */
63 public final class MockWebServer {
64     private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
65         @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
66                 throws CertificateException {
67             throw new CertificateException();
68         }
69 
70         @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
71             throw new AssertionError();
72         }
73 
74         @Override public X509Certificate[] getAcceptedIssuers() {
75             throw new AssertionError();
76         }
77     };
78 
79     private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
80 
81     private final BlockingQueue<RecordedRequest> requestQueue
82             = new LinkedBlockingQueue<RecordedRequest>();
83 
84     /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
85     private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
86     private final AtomicInteger requestCount = new AtomicInteger();
87     private int bodyLimit = Integer.MAX_VALUE;
88     private ServerSocket serverSocket;
89     private SSLSocketFactory sslSocketFactory;
90     private ExecutorService acceptExecutor;
91     private ExecutorService requestExecutor;
92     private boolean tunnelProxy;
93     private Dispatcher dispatcher = new QueueDispatcher();
94 
95     private int port = -1;
96     private int workerThreads = Integer.MAX_VALUE;
97 
getPort()98     public int getPort() {
99         if (port == -1) {
100             throw new IllegalStateException("Cannot retrieve port before calling play()");
101         }
102         return port;
103     }
104 
getHostName()105     public String getHostName() {
106         try {
107             return InetAddress.getLocalHost().getHostName();
108         } catch (UnknownHostException e) {
109             throw new AssertionError(e);
110         }
111     }
112 
toProxyAddress()113     public Proxy toProxyAddress() {
114         return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
115     }
116 
117     /**
118      * Returns a URL for connecting to this server.
119      *
120      * @param path the request path, such as "/".
121      */
getUrl(String path)122     public URL getUrl(String path) {
123         try {
124             return sslSocketFactory != null
125                     ? new URL("https://" + getHostName() + ":" + getPort() + path)
126                     : new URL("http://" + getHostName() + ":" + getPort() + path);
127         } catch (MalformedURLException e) {
128             throw new AssertionError(e);
129         }
130     }
131 
132     /**
133      * Returns a cookie domain for this server. This returns the server's
134      * non-loopback host name if it is known. Otherwise this returns ".local"
135      * for this server's loopback name.
136      */
getCookieDomain()137     public String getCookieDomain() {
138         String hostName = getHostName();
139         return hostName.contains(".") ? hostName : ".local";
140     }
141 
setWorkerThreads(int threads)142     public void setWorkerThreads(int threads) {
143         this.workerThreads = threads;
144     }
145 
146     /**
147      * Sets the number of bytes of the POST body to keep in memory to the given
148      * limit.
149      */
setBodyLimit(int maxBodyLength)150     public void setBodyLimit(int maxBodyLength) {
151         this.bodyLimit = maxBodyLength;
152     }
153 
154     /**
155      * Serve requests with HTTPS rather than otherwise.
156      *
157      * @param tunnelProxy whether to expect the HTTP CONNECT method before
158      *     negotiating TLS.
159      */
useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy)160     public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
161         this.sslSocketFactory = sslSocketFactory;
162         this.tunnelProxy = tunnelProxy;
163     }
164 
165     /**
166      * Awaits the next HTTP request, removes it, and returns it. Callers should
167      * use this to verify the request sent was as intended.
168      */
takeRequest()169     public RecordedRequest takeRequest() throws InterruptedException {
170         return requestQueue.take();
171     }
172 
173     /**
174      * Returns the number of HTTP requests received thus far by this server.
175      * This may exceed the number of HTTP connections when connection reuse is
176      * in practice.
177      */
getRequestCount()178     public int getRequestCount() {
179         return requestCount.get();
180     }
181 
182     /**
183      * Scripts {@code response} to be returned to a request made in sequence.
184      * The first request is served by the first enqueued response; the second
185      * request by the second enqueued response; and so on.
186      *
187      * @throws ClassCastException if the default dispatcher has been replaced
188      *     with {@link #setDispatcher(Dispatcher)}.
189      */
enqueue(MockResponse response)190     public void enqueue(MockResponse response) {
191         ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
192     }
193 
194     /**
195      * Equivalent to {@code play(0)}.
196      */
play()197     public void play() throws IOException {
198         play(0);
199     }
200 
201     /**
202      * Starts the server, serves all enqueued requests, and shuts the server
203      * down.
204      *
205      * @param port the port to listen to, or 0 for any available port.
206      *     Automated tests should always use port 0 to avoid flakiness when a
207      *     specific port is unavailable.
208      */
play(int port)209     public void play(int port) throws IOException {
210         if (acceptExecutor != null) {
211             throw new IllegalStateException("play() already called");
212         }
213         // The acceptExecutor handles the Socket.accept() and hands each request off to the
214         // requestExecutor. It also handles shutdown.
215         acceptExecutor = Executors.newSingleThreadExecutor();
216         // The requestExecutor has a fixed number of worker threads. In order to get strict
217         // guarantees that requests are handled in the order in which they are accepted
218         // workerThreads should be set to 1.
219         requestExecutor = Executors.newFixedThreadPool(workerThreads);
220         serverSocket = new ServerSocket(port);
221         serverSocket.setReuseAddress(true);
222 
223         this.port = serverSocket.getLocalPort();
224         acceptExecutor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
225             public void run() {
226                 try {
227                     acceptConnections();
228                 } catch (Throwable e) {
229                     logger.log(Level.WARNING, "MockWebServer connection failed", e);
230                 }
231 
232                 /*
233                  * This gnarly block of code will release all sockets and
234                  * all thread, even if any close fails.
235                  */
236                 try {
237                     serverSocket.close();
238                 } catch (Throwable e) {
239                     logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
240                 }
241                 for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) {
242                     try {
243                         s.next().close();
244                         s.remove();
245                     } catch (Throwable e) {
246                         logger.log(Level.WARNING, "MockWebServer socket close failed", e);
247                     }
248                 }
249                 try {
250                     acceptExecutor.shutdown();
251                 } catch (Throwable e) {
252                     logger.log(Level.WARNING, "MockWebServer acceptExecutor shutdown failed", e);
253                 }
254                 try {
255                     requestExecutor.shutdown();
256                 } catch (Throwable e) {
257                     logger.log(Level.WARNING, "MockWebServer requestExecutor shutdown failed", e);
258                 }
259             }
260 
261             private void acceptConnections() throws Exception {
262                 while (true) {
263                     Socket socket;
264                     try {
265                         socket = serverSocket.accept();
266                     } catch (SocketException e) {
267                         return;
268                     }
269                     SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
270                     if (socketPolicy == DISCONNECT_AT_START) {
271                         dispatchBookkeepingRequest(0, socket);
272                         socket.close();
273                     } else {
274                         openClientSockets.put(socket, true);
275                         serveConnection(socket);
276                     }
277                 }
278             }
279         }));
280     }
281 
shutdown()282     public void shutdown() throws IOException {
283         if (serverSocket != null) {
284             serverSocket.close(); // should cause acceptConnections() to break out
285         }
286     }
287 
serveConnection(final Socket raw)288     private void serveConnection(final Socket raw) {
289         String name = "MockWebServer-" + raw.getRemoteSocketAddress();
290         requestExecutor.execute(namedRunnable(name, new Runnable() {
291             int sequenceNumber = 0;
292 
293             public void run() {
294                 try {
295                     processConnection();
296                 } catch (Exception e) {
297                     logger.log(Level.WARNING, "MockWebServer connection failed", e);
298                 }
299             }
300 
301             public void processConnection() throws Exception {
302                 Socket socket;
303                 if (sslSocketFactory != null) {
304                     if (tunnelProxy) {
305                         createTunnel();
306                     }
307                     SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
308                     if (socketPolicy == FAIL_HANDSHAKE) {
309                         dispatchBookkeepingRequest(sequenceNumber, raw);
310                         processHandshakeFailure(raw);
311                         return;
312                     }
313                     socket = sslSocketFactory.createSocket(
314                             raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
315                     SSLSocket sslSocket = (SSLSocket) socket;
316                     sslSocket.setUseClientMode(false);
317                     openClientSockets.put(socket, true);
318 
319                     sslSocket.startHandshake();
320 
321                     openClientSockets.remove(raw);
322                 } else {
323                     socket = raw;
324                 }
325 
326                 InputStream in = new BufferedInputStream(socket.getInputStream());
327                 OutputStream out = new BufferedOutputStream(socket.getOutputStream());
328 
329                 while (processOneRequest(socket, in, out)) {
330                 }
331 
332                 if (sequenceNumber == 0) {
333                     logger.warning("MockWebServer connection didn't make a request");
334                 }
335 
336                 in.close();
337                 out.close();
338                 socket.close();
339                 openClientSockets.remove(socket);
340             }
341 
342             /**
343              * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response
344              * is dispatched.
345              */
346             private void createTunnel() throws IOException, InterruptedException {
347                 while (true) {
348                     SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
349                     if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
350                         throw new IllegalStateException("Tunnel without any CONNECT!");
351                     }
352                     if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
353                 }
354             }
355 
356             /**
357              * Reads a request and writes its response. Returns true if a request
358              * was processed.
359              */
360             private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
361                     throws IOException, InterruptedException {
362                 RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
363                 if (request == null) {
364                     return false;
365                 }
366                 requestCount.incrementAndGet();
367                 requestQueue.add(request);
368                 MockResponse response = dispatcher.dispatch(request);
369                 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_READING_REQUEST) {
370                   logger.info("Received request: " + request + " and disconnected without responding");
371                   return false;
372                 }
373                 writeResponse(out, response);
374 
375                 // For socket policies that poison the socket after the response is written:
376                 // The client has received the response and will no longer be blocked after
377                 // writeResponse() has returned. A client can then re-use the connection before
378                 // the socket is poisoned (i.e. keep-alive / connection pooling). The second
379                 // request/response may fail at the beginning, middle, end, or even succeed
380                 // depending on scheduling. Delays can be required in tests to improve the chances
381                 // of sockets being in a known state when subsequent requests are made.
382                 //
383                 // For SHUTDOWN_OUTPUT_AT_END the client may detect a problem with its input socket
384                 // after the request has been made but before the server has chosen a response.
385                 // For clients that perform retries, this can cause the client to issue a retry
386                 // request. The retry handler may call dispatcher.dispatch(request) before the
387                 // initial, failed request handler does and cause non-obvious response ordering.
388                 // Setting workerThreads = 1 ensures that the dispatcher is called for requests in
389                 // the order they are received.
390 
391                 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
392                     in.close();
393                     out.close();
394                 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
395                     socket.shutdownInput();
396                 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
397                     socket.shutdownOutput();
398                 }
399                 logger.info("Received request: " + request + " and responded: " + response);
400                 sequenceNumber++;
401                 return true;
402             }
403         }));
404     }
405 
processHandshakeFailure(Socket raw)406     private void processHandshakeFailure(Socket raw) throws Exception {
407         SSLContext context = SSLContext.getInstance("TLS");
408         context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
409         SSLSocketFactory sslSocketFactory = context.getSocketFactory();
410         SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
411                 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
412         try {
413             socket.startHandshake(); // we're testing a handshake failure
414             throw new AssertionError();
415         } catch (IOException expected) {
416         }
417         socket.close();
418     }
419 
dispatchBookkeepingRequest(int sequenceNumber, Socket socket)420     private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) throws InterruptedException {
421         requestCount.incrementAndGet();
422         RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
423                 socket);
424         dispatcher.dispatch(request);
425     }
426 
427     /** @param sequenceNumber the index of this request on this connection. */
readRequest(Socket socket, InputStream in, OutputStream out, int sequenceNumber)428     private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
429             int sequenceNumber) throws IOException {
430         String request;
431         try {
432             request = readAsciiUntilCrlf(in);
433         } catch (IOException streamIsClosed) {
434             return null; // no request because we closed the stream
435         }
436         if (request.length() == 0) {
437             return null; // no request because the stream is exhausted
438         }
439 
440         List<String> headers = new ArrayList<String>();
441         long contentLength = -1;
442         boolean chunked = false;
443         boolean expectContinue = false;
444         String header;
445         while ((header = readAsciiUntilCrlf(in)).length() != 0) {
446             headers.add(header);
447             String lowercaseHeader = header.toLowerCase(Locale.US);
448             if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
449                 contentLength = Long.parseLong(header.substring(15).trim());
450             }
451             if (lowercaseHeader.startsWith("transfer-encoding:")
452                     && lowercaseHeader.substring(18).trim().equals("chunked")) {
453                 chunked = true;
454             }
455             if (lowercaseHeader.startsWith("expect:")
456                     && lowercaseHeader.substring(7).trim().equals("100-continue")) {
457                 expectContinue = true;
458             }
459         }
460 
461         if (expectContinue) {
462             out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII));
463             out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII));
464             out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
465             out.flush();
466         }
467 
468         boolean hasBody = false;
469         TruncatingOutputStream requestBody = new TruncatingOutputStream();
470         List<Integer> chunkSizes = new ArrayList<Integer>();
471         MockResponse throttlePolicy = dispatcher.peek();
472         if (contentLength != -1) {
473             hasBody = true;
474             throttledTransfer(throttlePolicy, in, requestBody, contentLength);
475         } else if (chunked) {
476             hasBody = true;
477             while (true) {
478                 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
479                 if (chunkSize == 0) {
480                     readEmptyLine(in);
481                     break;
482                 }
483                 chunkSizes.add(chunkSize);
484                 throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
485                 readEmptyLine(in);
486             }
487         }
488 
489         if (request.startsWith("OPTIONS ")
490                 || request.startsWith("GET ")
491                 || request.startsWith("HEAD ")
492                 || request.startsWith("TRACE ")
493                 || request.startsWith("CONNECT ")) {
494             if (hasBody) {
495                 throw new IllegalArgumentException("Request must not have a body: " + request);
496             }
497         } else if (!request.startsWith("POST ")
498                 && !request.startsWith("PUT ")
499                 && !request.startsWith("PATCH ")
500                 && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
501             throw new UnsupportedOperationException("Unexpected method: " + request);
502         }
503 
504         return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
505                 requestBody.toByteArray(), sequenceNumber, socket);
506     }
507 
writeResponse(OutputStream out, MockResponse response)508     private void writeResponse(OutputStream out, MockResponse response) throws IOException {
509         out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII));
510         List<String> headers = response.getHeaders();
511         for (int i = 0, size = headers.size(); i < size; i++) {
512             String header = headers.get(i);
513             out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII));
514         }
515         out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
516         out.flush();
517 
518         InputStream in = response.getBodyStream();
519         if (in == null) return;
520         throttledTransfer(response, in, out, Long.MAX_VALUE);
521     }
522 
523     /**
524      * Transfer bytes from {@code in} to {@code out} until either {@code length}
525      * bytes have been transferred or {@code in} is exhausted. The transfer is
526      * throttled according to {@code throttlePolicy}.
527      */
throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out, long limit)528     private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
529             long limit) throws IOException {
530         byte[] buffer = new byte[1024];
531         int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
532         long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
533 
534         while (true) {
535             for (int b = 0; b < bytesPerPeriod; ) {
536                 int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
537                 int read = in.read(buffer, 0, toRead);
538                 if (read == -1) return;
539 
540                 out.write(buffer, 0, read);
541                 out.flush();
542                 b += read;
543                 limit -= read;
544 
545                 if (limit == 0) return;
546             }
547 
548             try {
549                 if (delayMs != 0) Thread.sleep(delayMs);
550             } catch (InterruptedException e) {
551                 throw new AssertionError();
552             }
553         }
554     }
555 
556     /**
557      * Returns the text from {@code in} until the next "\r\n", or null if
558      * {@code in} is exhausted.
559      */
readAsciiUntilCrlf(InputStream in)560     private String readAsciiUntilCrlf(InputStream in) throws IOException {
561         StringBuilder builder = new StringBuilder();
562         while (true) {
563             int c = in.read();
564             if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
565                 builder.deleteCharAt(builder.length() - 1);
566                 return builder.toString();
567             } else if (c == -1) {
568                 return builder.toString();
569             } else {
570                 builder.append((char) c);
571             }
572         }
573     }
574 
readEmptyLine(InputStream in)575     private void readEmptyLine(InputStream in) throws IOException {
576         String line = readAsciiUntilCrlf(in);
577         if (line.length() != 0) {
578             throw new IllegalStateException("Expected empty but was: " + line);
579         }
580     }
581 
582     /**
583      * Sets the dispatcher used to match incoming requests to mock responses.
584      * The default dispatcher simply serves a fixed sequence of responses from
585      * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
586      * response based on timing or the content of the request.
587      */
setDispatcher(Dispatcher dispatcher)588     public void setDispatcher(Dispatcher dispatcher) {
589         if (dispatcher == null) {
590             throw new NullPointerException();
591         }
592         this.dispatcher = dispatcher;
593     }
594 
595     /**
596      * An output stream that drops data after bodyLimit bytes.
597      */
598     private class TruncatingOutputStream extends ByteArrayOutputStream {
599         private int numBytesReceived = 0;
write(byte[] buffer, int offset, int len)600         @Override public void write(byte[] buffer, int offset, int len) {
601             numBytesReceived += len;
602             super.write(buffer, offset, Math.min(len, bodyLimit - count));
603         }
write(int oneByte)604         @Override public void write(int oneByte) {
605             numBytesReceived++;
606             if (count < bodyLimit) {
607                 super.write(oneByte);
608             }
609         }
610     }
611 
namedRunnable(final String name, final Runnable runnable)612     private static Runnable namedRunnable(final String name, final Runnable runnable) {
613         return new Runnable() {
614             public void run() {
615                 String originalName = Thread.currentThread().getName();
616                 Thread.currentThread().setName(name);
617                 try {
618                     runnable.run();
619                 } finally {
620                     Thread.currentThread().setName(originalName);
621                 }
622             }
623         };
624     }
625 }
626