1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.security.net.config.cts;
18 
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import android.app.DownloadManager;
import android.content.BroadcastReceiver;
import android.content.Context;
import android.content.Intent;
import android.content.IntentFilter;
import android.database.Cursor;
import android.net.Uri;
import android.os.SystemClock;
import android.security.net.config.cts.CtsNetSecConfigDownloadManagerTestCases.R;
import android.text.format.DateUtils;

import androidx.test.runner.AndroidJUnit4;

import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.Callable;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.TrustManagerFactory;
61 
62 @RunWith(AndroidJUnit4.class)
63 public class DownloadManagerTest extends BaseTestCase {
64 
65     private static final String HTTP_RESPONSE =
66             "HTTP/1.0 200 OK\r\nContent-Type: text/plain\r\nContent-length: 5\r\n\r\nhello";
67     private static final long TIMEOUT = 3 * DateUtils.SECOND_IN_MILLIS;
68 
69     @Test
testConfigTrustedCaAccepted()70     public void testConfigTrustedCaAccepted() throws Exception {
71         SSLServerSocket serverSocket = bindTLSServer(R.raw.valid_chain, R.raw.test_key);
72         runDownloadManagerTest(serverSocket, true);
73     }
74 
75     @Test
testUntrustedCaRejected()76     public void testUntrustedCaRejected() throws Exception {
77         try {
78             SSLServerSocket serverSocket = bindTLSServer(R.raw.invalid_chain, R.raw.test_key);
79             runDownloadManagerTest(serverSocket, true);
80             fail("Invalid CA should be rejected");
81         } catch (Exception expected) {
82         }
83     }
84 
85     @Test
testPerDomainCleartextAccepted()86     public void testPerDomainCleartextAccepted() throws Exception {
87         ServerSocket serverSocket = new ServerSocket();
88         serverSocket.bind(null);
89         runDownloadManagerTest(serverSocket, false);
90     }
91 
runDownloadManagerTest(ServerSocket serverSocket, boolean https)92     private void runDownloadManagerTest(ServerSocket serverSocket, boolean https) throws Exception {
93         DownloadManager dm =  mContext.getSystemService(DownloadManager.class);
94         DownloadCompleteReceiver receiver = new DownloadCompleteReceiver();
95         FutureTask<Void> serverFuture = new FutureTask<Void>(new Callable() {
96             @Override
97             public Void call() throws Exception {
98                 runServer(serverSocket);
99                 return null;
100             }
101         });
102         try {
103             IntentFilter filter = new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE);
104             mContext.registerReceiver(receiver, filter, Context.RECEIVER_EXPORTED);
105             new Thread(serverFuture).start();
106             String host = (https ? "https" : "http") + "://localhost";
107             Uri destination = Uri.parse(host + ":" + serverSocket.getLocalPort());
108             long id = dm.enqueue(new DownloadManager.Request(destination));
109             try {
110                 serverFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
111                 // Check that the download was successful.
112                 receiver.waitForDownloadComplete(TIMEOUT, id);
113                 assertSuccessfulDownload(id);
114             } catch (InterruptedException e) {
115                 // Wrap InterruptedException since otherwise it gets eaten by AndroidTest
116                 throw new RuntimeException(e);
117             } finally {
118                 dm.remove(id);
119             }
120         } finally {
121             mContext.unregisterReceiver(receiver);
122             serverFuture.cancel(true);
123             try {
124                 serverSocket.close();
125             } catch (Exception ignored) {}
126         }
127     }
128 
runServer(ServerSocket server)129     private void runServer(ServerSocket server) throws Exception {
130         Socket s = server.accept();
131         s.getOutputStream().write(HTTP_RESPONSE.getBytes());
132         s.getOutputStream().flush();
133         s.close();
134     }
135 
bindTLSServer(int chainResId, int keyResId)136     private SSLServerSocket bindTLSServer(int chainResId, int keyResId) throws Exception {
137         // Load certificate chain.
138         CertificateFactory fact = CertificateFactory.getInstance("X.509");
139         Collection<? extends Certificate> certs;
140         try (InputStream is = mContext.getResources().openRawResource(chainResId)) {
141             certs = fact.generateCertificates(is);
142         }
143         X509Certificate[] chain = new X509Certificate[certs.size()];
144         int i = 0;
145         for (Certificate cert : certs) {
146             chain[i++] = (X509Certificate) cert;
147         }
148 
149         // Load private key for the leaf.
150         PrivateKey key;
151         try (InputStream is = mContext.getResources().openRawResource(keyResId)) {
152             ByteArrayOutputStream keyout = new ByteArrayOutputStream();
153             byte[] buffer = new byte[4096];
154             int chunk_size;
155             while ((chunk_size = is.read(buffer)) != -1) {
156                 keyout.write(buffer, 0, chunk_size);
157             }
158             is.close();
159             byte[] keyBytes = keyout.toByteArray();
160             key = KeyFactory.getInstance("RSA")
161                     .generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
162         }
163 
164         // Create KeyStore based on the private key/chain.
165         KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
166         ks.load(null);
167         ks.setKeyEntry("name", key, null, chain);
168 
169         // Create SSLContext.
170         TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX");
171         tmf.init(ks);
172         KeyManagerFactory kmf =
173                 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
174         kmf.init(ks, null);
175         SSLContext context = SSLContext.getInstance("TLS");
176         context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
177 
178         SSLServerSocket s = (SSLServerSocket) context.getServerSocketFactory().createServerSocket();
179         s.bind(null);
180         return s;
181     }
182 
assertSuccessfulDownload(long id)183     private void assertSuccessfulDownload(long id) throws Exception {
184         Cursor cursor = null;
185         DownloadManager dm = mContext.getSystemService(DownloadManager.class);
186         try {
187             cursor = dm.query(new DownloadManager.Query().setFilterById(id));
188             assertTrue(cursor.moveToNext());
189             assertEquals(DownloadManager.STATUS_SUCCESSFUL, cursor.getInt(
190                     cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)));
191         } finally {
192             if (cursor != null) {
193                 cursor.close();
194             }
195         }
196     }
197 
198     private static final class DownloadCompleteReceiver extends BroadcastReceiver {
199         private HashSet<Long> mCompletedDownloads = new HashSet<>();
200 
DownloadCompleteReceiver()201         public DownloadCompleteReceiver() {
202         }
203 
204         @Override
onReceive(Context context, Intent intent)205         public void onReceive(Context context, Intent intent) {
206             synchronized(mCompletedDownloads) {
207                 mCompletedDownloads.add(intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1));
208                 mCompletedDownloads.notifyAll();
209             }
210         }
211 
waitForDownloadComplete(long timeout, long id)212         public void waitForDownloadComplete(long timeout, long id)
213                 throws TimeoutException, InterruptedException  {
214             long deadline = SystemClock.elapsedRealtime() + timeout;
215             do {
216                 synchronized (mCompletedDownloads) {
217                     long millisTillTimeout = deadline - SystemClock.elapsedRealtime();
218                     if (millisTillTimeout > 0) {
219                         mCompletedDownloads.wait(millisTillTimeout);
220                     }
221                     if (mCompletedDownloads.contains(id)) {
222                         return;
223                     }
224                 }
225             } while (SystemClock.elapsedRealtime() < deadline);
226 
227             throw new TimeoutException("Timed out waiting for download complete");
228         }
229     }
230 
231 
232 }
233