1 package com.android.hotspot2.est;
2 
import android.net.Network;
import android.util.Base64;
import android.util.Log;

import com.android.hotspot2.OMADMAdapter;
import com.android.hotspot2.asn1.Asn1Class;
import com.android.hotspot2.asn1.Asn1Constructed;
import com.android.hotspot2.asn1.Asn1Decoder;
import com.android.hotspot2.asn1.Asn1ID;
import com.android.hotspot2.asn1.Asn1Integer;
import com.android.hotspot2.asn1.Asn1Object;
import com.android.hotspot2.asn1.Asn1Oid;
import com.android.hotspot2.asn1.OidMappings;
import com.android.hotspot2.osu.HTTPHandler;
import com.android.hotspot2.osu.OSUSocketFactory;
import com.android.hotspot2.osu.commands.GetCertData;
import com.android.hotspot2.pps.HomeSP;
import com.android.hotspot2.utils.HTTPMessage;
import com.android.hotspot2.utils.HTTPResponse;
import com.android.org.bouncycastle.asn1.ASN1Encodable;
import com.android.org.bouncycastle.asn1.ASN1EncodableVector;
import com.android.org.bouncycastle.asn1.ASN1Set;
import com.android.org.bouncycastle.asn1.DERBitString;
import com.android.org.bouncycastle.asn1.DEREncodableVector;
import com.android.org.bouncycastle.asn1.DERIA5String;
import com.android.org.bouncycastle.asn1.DERObjectIdentifier;
import com.android.org.bouncycastle.asn1.DERPrintableString;
import com.android.org.bouncycastle.asn1.DERSet;
import com.android.org.bouncycastle.asn1.x509.Attribute;
import com.android.org.bouncycastle.jce.PKCS10CertificationRequest;
import com.android.org.bouncycastle.jce.spec.ECNamedCurveGenParameterSpec;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AlgorithmParameters;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.ECGenParameterSpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.net.ssl.KeyManager;
import javax.security.auth.x500.X500Principal;
60 
61 //import com.android.org.bouncycastle.jce.provider.BouncyCastleProvider;
62 
63 public class ESTHandler implements AutoCloseable {
64     private static final String TAG = "HS2EST";
65     private static final int MinRSAKeySize = 2048;
66 
67     private static final String CACERT_PATH = "/cacerts";
68     private static final String CSR_PATH = "/csrattrs";
69     private static final String SIMPLE_ENROLL_PATH = "/simpleenroll";
70     private static final String SIMPLE_REENROLL_PATH = "/simplereenroll";
71 
72     private final URL mURL;
73     private final String mUser;
74     private final byte[] mPassword;
75     private final OSUSocketFactory mSocketFactory;
76     private final OMADMAdapter mOMADMAdapter;
77 
78     private final List<X509Certificate> mCACerts = new ArrayList<>();
79     private final List<X509Certificate> mClientCerts = new ArrayList<>();
80     private PrivateKey mClientKey;
81 
ESTHandler(GetCertData certData, Network network, OMADMAdapter omadmAdapter, KeyManager km, KeyStore ks, HomeSP homeSP, int flowType)82     public ESTHandler(GetCertData certData, Network network, OMADMAdapter omadmAdapter,
83                       KeyManager km, KeyStore ks, HomeSP homeSP, int flowType)
84             throws IOException, GeneralSecurityException {
85         mURL = new URL(certData.getServer());
86         mUser = certData.getUserName();
87         mPassword = certData.getPassword();
88         mSocketFactory = OSUSocketFactory.getSocketFactory(ks, homeSP, flowType,
89                 network, mURL, km, true);
90         mOMADMAdapter = omadmAdapter;
91     }
92 
93     @Override
close()94     public void close() throws IOException {
95     }
96 
getCACerts()97     public List<X509Certificate> getCACerts() {
98         return mCACerts;
99     }
100 
getClientCerts()101     public List<X509Certificate> getClientCerts() {
102         return mClientCerts;
103     }
104 
getClientKey()105     public PrivateKey getClientKey() {
106         return mClientKey;
107     }
108 
indent(int amount)109     private static String indent(int amount) {
110         char[] indent = new char[amount * 2];
111         Arrays.fill(indent, ' ');
112         return new String(indent);
113     }
114 
execute(boolean reenroll)115     public void execute(boolean reenroll) throws IOException, GeneralSecurityException {
116         URL caURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
117                 mURL.getFile() + CACERT_PATH);
118 
119         HTTPResponse response;
120         try (HTTPHandler httpHandler = new HTTPHandler(StandardCharsets.ISO_8859_1, mSocketFactory,
121                 mUser, mPassword)) {
122             response = httpHandler.doGetHTTP(caURL);
123 
124             if (!"application/pkcs7-mime".equals(response.getHeaders().
125                     get(HTTPMessage.ContentTypeHeader))) {
126                 throw new IOException("Unexpected Content-Type: " +
127                         response.getHeaders().get(HTTPMessage.ContentTypeHeader));
128             }
129             ByteBuffer octetBuffer = response.getBinaryPayload();
130             Collection<Asn1Object> pkcs7Content1 = Asn1Decoder.decode(octetBuffer);
131             for (Asn1Object asn1Object : pkcs7Content1) {
132                 Log.d(TAG, "---");
133                 Log.d(TAG, asn1Object.toString());
134             }
135             Log.d(TAG, CACERT_PATH);
136 
137             mCACerts.addAll(unpackPkcs7(octetBuffer));
138             for (X509Certificate certificate : mCACerts) {
139                 Log.d(TAG, "CA-Cert: " + certificate.getSubjectX500Principal());
140             }
141 
142             /*
143             byte[] octets = new byte[octetBuffer.remaining()];
144             octetBuffer.duplicate().get(octets);
145             for (byte b : octets) {
146                 System.out.printf("%02x ", b & 0xff);
147             }
148             Log.d(TAG, );
149             */
150 
151             /* + BC
152             try {
153                 byte[] octets = new byte[octetBuffer.remaining()];
154                 octetBuffer.duplicate().get(octets);
155                 ASN1InputStream asnin = new ASN1InputStream(octets);
156                 for (int n = 0; n < 100; n++) {
157                     ASN1Primitive object = asnin.readObject();
158                     if (object == null) {
159                         break;
160                     }
161                     parseObject(object, 0);
162                 }
163             }
164             catch (Throwable t) {
165                 t.printStackTrace();
166             }
167 
168             Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(octetBuffer);
169             for (Asn1Object asn1Object : pkcs7Content) {
170                 Log.d(TAG, asn1Object);
171             }
172 
173             if (pkcs7Content.size() != 1) {
174                 throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
175             }
176 
177             Asn1Constructed pkcs7Root = (Asn1Constructed) pkcs7Content.iterator().next();
178             Iterator<Asn1ID> certPath = Arrays.asList(Pkcs7CertPath).iterator();
179             Asn1Object certObject = pkcs7Root.findObject(certPath);
180             if (certObject == null || certPath.hasNext()) {
181                 throw new IOException("Failed to find cert; returned object " + certObject +
182                         ", path " + (certPath.hasNext() ? "short" : "exhausted"));
183             }
184 
185             ByteBuffer certOctets = certObject.getPayload();
186             if (certOctets == null) {
187                 throw new IOException("No cert payload in: " + certObject);
188             }
189 
190             byte[] certBytes = new byte[certOctets.remaining()];
191             certOctets.get(certBytes);
192 
193             CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
194             Certificate cert = certFactory.generateCertificate(new ByteArrayInputStream(certBytes));
195             Log.d(TAG, "EST Cert: " + cert);
196             */
197 
198             URL csrURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
199                     mURL.getFile() + CSR_PATH);
200             response = httpHandler.doGetHTTP(csrURL);
201 
202             octetBuffer = response.getBinaryPayload();
203             byte[] csrData = buildCSR(octetBuffer, mOMADMAdapter, httpHandler);
204 
205         /**/
206             Collection<Asn1Object> o = Asn1Decoder.decode(ByteBuffer.wrap(csrData));
207             Log.d(TAG, "CSR:");
208             Log.d(TAG, o.iterator().next().toString());
209             Log.d(TAG, "End CSR.");
210         /**/
211 
212             URL enrollURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
213                     mURL.getFile() + (reenroll ? SIMPLE_REENROLL_PATH : SIMPLE_ENROLL_PATH));
214             String data = Base64.encodeToString(csrData, Base64.DEFAULT);
215             octetBuffer = httpHandler.exchangeBinary(enrollURL, data, "application/pkcs10");
216 
217             Collection<Asn1Object> pkcs7Content2 = Asn1Decoder.decode(octetBuffer);
218             for (Asn1Object asn1Object : pkcs7Content2) {
219                 Log.d(TAG, "---");
220                 Log.d(TAG, asn1Object.toString());
221             }
222             mClientCerts.addAll(unpackPkcs7(octetBuffer));
223             for (X509Certificate cert : mClientCerts) {
224                 Log.d(TAG, cert.toString());
225             }
226         }
227     }
228 
229     private static final Asn1ID sSEQUENCE = new Asn1ID(Asn1Decoder.TAG_SEQ, Asn1Class.Universal);
230     private static final Asn1ID sCTXT0 = new Asn1ID(0, Asn1Class.Context);
231     private static final int PKCS7DataVersion = 1;
232     private static final int PKCS7SignedDataVersion = 3;
233 
unpackPkcs7(ByteBuffer pkcs7)234     private static List<X509Certificate> unpackPkcs7(ByteBuffer pkcs7)
235             throws IOException, GeneralSecurityException {
236         Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(pkcs7);
237 
238         if (pkcs7Content.size() != 1) {
239             throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
240         }
241 
242         Asn1Object data = pkcs7Content.iterator().next();
243         if (!data.isConstructed() || !data.matches(sSEQUENCE)) {
244             throw new IOException("Expected SEQ OF, got " + data.toSimpleString());
245         } else if (data.getChildren().size() != 2) {
246             throw new IOException("Expected content info to have two children, got " +
247                     data.getChildren().size());
248         }
249 
250         Iterator<Asn1Object> children = data.getChildren().iterator();
251         Asn1Object contentType = children.next();
252         if (!contentType.equals(Asn1Oid.PKCS7SignedData)) {
253             throw new IOException("Content not PKCS7 signed data");
254         }
255         Asn1Object content = children.next();
256         if (!content.isConstructed() || !content.matches(sCTXT0)) {
257             throw new IOException("Expected [CONTEXT 0] with one child, got " +
258                     content.toSimpleString() + ", " + content.getChildren().size());
259         }
260 
261         Asn1Object signedData = content.getChildren().iterator().next();
262         Map<Integer, Asn1Object> itemMap = new HashMap<>();
263         for (Asn1Object item : signedData.getChildren()) {
264             if (itemMap.put(item.getTag(), item) != null && item.getTag() != Asn1Decoder.TAG_SET) {
265                 throw new IOException("Duplicate item in SignedData: " + item.toSimpleString());
266             }
267         }
268 
269         Asn1Object versionObject = itemMap.get(Asn1Decoder.TAG_INTEGER);
270         if (versionObject == null || !(versionObject instanceof Asn1Integer)) {
271             throw new IOException("Bad or missing PKCS7 version: " + versionObject);
272         }
273         int pkcs7version = (int) ((Asn1Integer) versionObject).getValue();
274         Asn1Object innerContentInfo = itemMap.get(Asn1Decoder.TAG_SEQ);
275         if (innerContentInfo == null ||
276                 !innerContentInfo.isConstructed() ||
277                 !innerContentInfo.matches(sSEQUENCE) ||
278                 innerContentInfo.getChildren().size() != 1) {
279             throw new IOException("Bad or missing PKCS7 contentInfo");
280         }
281         Asn1Object contentID = innerContentInfo.getChildren().iterator().next();
282         if (pkcs7version == PKCS7DataVersion && !contentID.equals(Asn1Oid.PKCS7Data) ||
283                 pkcs7version == PKCS7SignedDataVersion && !contentID.equals(Asn1Oid.PKCS7SignedData)) {
284             throw new IOException("Inner PKCS7 content (" + contentID +
285                     ") not expected for version " + pkcs7version);
286         }
287         Asn1Object certWrapper = itemMap.get(0);
288         if (certWrapper == null || !certWrapper.isConstructed() || !certWrapper.matches(sCTXT0)) {
289             throw new IOException("Expected [CONTEXT 0], got: " + certWrapper);
290         }
291 
292         List<X509Certificate> certList = new ArrayList<>(certWrapper.getChildren().size());
293         CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
294         for (Asn1Object certObject : certWrapper.getChildren()) {
295             ByteBuffer certOctets = ((Asn1Constructed) certObject).getEncoding();
296             if (certOctets == null) {
297                 throw new IOException("No cert payload in: " + certObject);
298             }
299             byte[] certBytes = new byte[certOctets.remaining()];
300             certOctets.get(certBytes);
301 
302             certList.add((X509Certificate) certFactory.
303                     generateCertificate(new ByteArrayInputStream(certBytes)));
304         }
305         return certList;
306     }
307 
buildCSR(ByteBuffer octetBuffer, OMADMAdapter omadmAdapter, HTTPHandler httpHandler)308     private byte[] buildCSR(ByteBuffer octetBuffer, OMADMAdapter omadmAdapter,
309                             HTTPHandler httpHandler) throws IOException, GeneralSecurityException {
310 
311         //Security.addProvider(new BouncyCastleProvider());
312 
313         Log.d(TAG, "/csrattrs:");
314         /*
315         byte[] octets = new byte[octetBuffer.remaining()];
316         octetBuffer.duplicate().get(octets);
317         for (byte b : octets) {
318             System.out.printf("%02x ", b & 0xff);
319         }
320         */
321         Collection<Asn1Object> csrs = Asn1Decoder.decode(octetBuffer);
322         for (Asn1Object asn1Object : csrs) {
323             Log.d(TAG, asn1Object.toString());
324         }
325 
326         if (csrs.size() != 1) {
327             throw new IOException("Unexpected object count in CSR attributes response: " +
328                     csrs.size());
329         }
330         Asn1Object sequence = csrs.iterator().next();
331         if (sequence.getClass() != Asn1Constructed.class) {
332             throw new IOException("Unexpected CSR attribute container: " + sequence);
333         }
334 
335         String keyAlgo = null;
336         Asn1Oid keyAlgoOID = null;
337         String sigAlgo = null;
338         String curveName = null;
339         Asn1Oid pubCrypto = null;
340         int keySize = -1;
341         Map<Asn1Oid, ASN1Encodable> idAttributes = new HashMap<>();
342 
343         for (Asn1Object child : sequence.getChildren()) {
344             if (child.getTag() == Asn1Decoder.TAG_OID) {
345                 Asn1Oid oid = (Asn1Oid) child;
346                 OidMappings.SigEntry sigEntry = OidMappings.getSigEntry(oid);
347                 if (sigEntry != null) {
348                     sigAlgo = sigEntry.getSigAlgo();
349                     keyAlgoOID = sigEntry.getKeyAlgo();
350                     keyAlgo = OidMappings.getJCEName(keyAlgoOID);
351                 } else if (oid.equals(OidMappings.sPkcs9AtChallengePassword)) {
352                     byte[] tlsUnique = httpHandler.getTLSUnique();
353                     if (tlsUnique != null) {
354                         idAttributes.put(oid, new DERPrintableString(
355                                 Base64.encodeToString(tlsUnique, Base64.DEFAULT)));
356                     } else {
357                         Log.w(TAG, "Cannot retrieve TLS unique channel binding");
358                     }
359                 }
360             } else if (child.getTag() == Asn1Decoder.TAG_SEQ) {
361                 Asn1Oid oid = null;
362                 Set<Asn1Oid> oidValues = new HashSet<>();
363                 List<Asn1Object> values = new ArrayList<>();
364 
365                 for (Asn1Object attributeSeq : child.getChildren()) {
366                     if (attributeSeq.getTag() == Asn1Decoder.TAG_OID) {
367                         oid = (Asn1Oid) attributeSeq;
368                     } else if (attributeSeq.getTag() == Asn1Decoder.TAG_SET) {
369                         for (Asn1Object value : attributeSeq.getChildren()) {
370                             if (value.getTag() == Asn1Decoder.TAG_OID) {
371                                 oidValues.add((Asn1Oid) value);
372                             } else {
373                                 values.add(value);
374                             }
375                         }
376                     }
377                 }
378                 if (oid == null) {
379                     throw new IOException("Invalid attribute, no OID");
380                 }
381                 if (oid.equals(OidMappings.sExtensionRequest)) {
382                     for (Asn1Oid subOid : oidValues) {
383                         if (OidMappings.isIDAttribute(subOid)) {
384                             if (subOid.equals(OidMappings.sMAC)) {
385                                 idAttributes.put(subOid, new DERIA5String(omadmAdapter.getMAC()));
386                             } else if (subOid.equals(OidMappings.sIMEI)) {
387                                 idAttributes.put(subOid, new DERIA5String(omadmAdapter.getImei()));
388                             } else if (subOid.equals(OidMappings.sMEID)) {
389                                 idAttributes.put(subOid, new DERBitString(omadmAdapter.getMeid()));
390                             } else if (subOid.equals(OidMappings.sDevID)) {
391                                 idAttributes.put(subOid,
392                                         new DERPrintableString(omadmAdapter.getDevID()));
393                             }
394                         }
395                     }
396                 } else if (OidMappings.getCryptoID(oid) != null) {
397                     pubCrypto = oid;
398                     if (!values.isEmpty()) {
399                         for (Asn1Object value : values) {
400                             if (value.getTag() == Asn1Decoder.TAG_INTEGER) {
401                                 keySize = (int) ((Asn1Integer) value).getValue();
402                             }
403                         }
404                     }
405                     if (oid.equals(OidMappings.sAlgo_EC)) {
406                         if (oidValues.isEmpty()) {
407                             throw new IOException("No ECC curve name provided");
408                         }
409                         for (Asn1Oid value : oidValues) {
410                             curveName = OidMappings.getJCEName(value);
411                             if (curveName != null) {
412                                 break;
413                             }
414                         }
415                         if (curveName == null) {
416                             throw new IOException("Found no ECC curve for " + oidValues);
417                         }
418                     }
419                 }
420             }
421         }
422 
423         if (keyAlgoOID == null) {
424             throw new IOException("No public key algorithm specified");
425         }
426         if (pubCrypto != null && !pubCrypto.equals(keyAlgoOID)) {
427             throw new IOException("Mismatching key algorithms");
428         }
429 
430         if (keyAlgoOID.equals(OidMappings.sAlgo_RSA)) {
431             if (keySize < MinRSAKeySize) {
432                 if (keySize >= 0) {
433                     Log.i(TAG, "Upgrading suggested RSA key size from " +
434                             keySize + " to " + MinRSAKeySize);
435                 }
436                 keySize = MinRSAKeySize;
437             }
438         }
439 
440         Log.d(TAG, String.format("pub key '%s', signature '%s', ECC curve '%s', id-atts %s",
441                 keyAlgo, sigAlgo, curveName, idAttributes));
442 
443         /*
444           Ruckus:
445             SEQUENCE:
446               OID=1.2.840.113549.1.1.11 (algo_id_sha256WithRSAEncryption)
447 
448           RFC-7030:
449             SEQUENCE:
450               OID=1.2.840.113549.1.9.7 (challengePassword)
451               SEQUENCE:
452                 OID=1.2.840.10045.2.1 (algo_id_ecPublicKey)
453                 SET:
454                   OID=1.3.132.0.34 (secp384r1)
455               SEQUENCE:
456                 OID=1.2.840.113549.1.9.14 (extensionRequest)
457                 SET:
458                   OID=1.3.6.1.1.1.1.22 (mac-address)
459               OID=1.2.840.10045.4.3.3 (eccdaWithSHA384)
460 
461               1L, 3L, 6L, 1L, 1L, 1L, 1L, 22
462          */
463 
464         // ECC Does not appear to be supported currently
465         KeyPairGenerator kpg = KeyPairGenerator.getInstance(keyAlgo);
466         if (curveName != null) {
467             AlgorithmParameters algorithmParameters = AlgorithmParameters.getInstance(keyAlgo);
468             algorithmParameters.init(new ECNamedCurveGenParameterSpec(curveName));
469             kpg.initialize(algorithmParameters
470                     .getParameterSpec(ECNamedCurveGenParameterSpec.class));
471         } else {
472             kpg.initialize(keySize);
473         }
474         KeyPair kp = kpg.generateKeyPair();
475 
476         X500Principal subject = new X500Principal("CN=Android, O=Google, C=US");
477 
478         mClientKey = kp.getPrivate();
479 
480         // !!! Map the idAttributes into an ASN1Set of values to pass to
481         // the PKCS10CertificationRequest - this code is using outdated BC classes and
482         // has *not* been tested.
483         ASN1Set attributes;
484         if (!idAttributes.isEmpty()) {
485             ASN1EncodableVector payload = new DEREncodableVector();
486             for (Map.Entry<Asn1Oid, ASN1Encodable> entry : idAttributes.entrySet()) {
487                 DERObjectIdentifier type = new DERObjectIdentifier(entry.getKey().toOIDString());
488                 ASN1Set values = new DERSet(entry.getValue());
489                 Attribute attribute = new Attribute(type, values);
490                 payload.add(attribute);
491             }
492             attributes = new DERSet(payload);
493         } else {
494             attributes = null;
495         }
496 
497         return new PKCS10CertificationRequest(sigAlgo, subject, kp.getPublic(),
498                 attributes, mClientKey).getEncoded();
499     }
500 }
501