1# Copyright 2014 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Routines to generate root and server certificates.
16
17Certificate Naming Conventions:
18  ca_cert:  crypto.X509 for the certificate authority (w/ both the pub &
19                priv keys)
20  cert:  a crypto.X509 certificate (w/ just the pub key)
21  cert_str:  a certificate string (w/ just the pub cert)
22  key:  a private crypto.PKey  (from ca or pem)
  ca_cert_str:  a certificate authority string (w/ both the pub cert & priv key)
24"""
25
26import logging
27import os
28import socket
29import time
30
# Holds the ImportError raised when pyOpenSSL is unavailable, so the module
# itself still imports; several functions below re-raise it on use.
openssl_import_error = None

# Module-level aliases for pyOpenSSL names. They stay None when the import
# below fails.
Error = None
SSL_METHOD = None
SysCallError = None
VERIFY_PEER = None
ZeroReturnError = None
FILETYPE_PEM = None

try:
  from OpenSSL import crypto, SSL

  Error = SSL.Error
  SSL_METHOD = SSL.SSLv23_METHOD
  SysCallError = SSL.SysCallError
  VERIFY_PEER = SSL.VERIFY_PEER
  ZeroReturnError = SSL.ZeroReturnError
  FILETYPE_PEM = crypto.FILETYPE_PEM
except ImportError as e:
  # `except X as e` works on Python 2.6+ and 3; `except X, e` is 2-only.
  openssl_import_error = e
51
52
def get_ssl_context(method=SSL_METHOD):
  """Creates an SSL.Context for the given protocol method.

  Args:
    method: a pyOpenSSL method constant such as SSLv2_METHOD, SSLv3_METHOD,
      SSLv23_METHOD or TLSv1_METHOD. Defaults to SSLv23_METHOD.
  Returns:
    A new SSL.Context.
  Raises:
    ImportError: if pyOpenSSL could not be imported.
  """
  if not openssl_import_error:
    return SSL.Context(method)
  raise openssl_import_error  # pylint: disable=raising-bad-type
58
59
class WrappedConnection(object):
  """Proxy around an SSL.Connection whose recv() reports end-of-stream as ''.

  Every attribute other than recv() is forwarded to the wrapped object.
  """

  def __init__(self, obj):
    self._wrapped_obj = obj

  def __getattr__(self, attr):
    # __getattr__ only runs after normal attribute lookup has failed, so a
    # check against self.__dict__ followed by getattr(self, ...) can never
    # help. Read the wrapped object straight from __dict__: if it has not
    # been set (e.g. an instance built via __new__ by copy/pickle), going
    # through self._wrapped_obj would re-enter __getattr__ forever.
    try:
      wrapped = self.__dict__['_wrapped_obj']
    except KeyError:
      raise AttributeError(attr)
    return getattr(wrapped, attr)

  def recv(self, buflen=1024, flags=0):
    """Reads up to buflen bytes, returning '' once the peer has closed.

    Args:
      buflen: maximum number of bytes to read.
      flags: passed through unchanged to the wrapped recv().
    Returns:
      The data read, or '' on SSL.ZeroReturnError or on an SSL.SysCallError
      whose message is 'Unexpected EOF'.
    Raises:
      SSL.SysCallError: for any other system-call failure.
    """
    try:
      return self._wrapped_obj.recv(buflen, flags)
    except SSL.SysCallError as e:
      # Slice rather than index so a SysCallError carrying fewer than two
      # args is re-raised as-is instead of being masked by an IndexError.
      if e.args[1:2] == ('Unexpected EOF',):
        return ''
      raise
    except SSL.ZeroReturnError:
      return ''
79
80
def get_ssl_connection(context, connection):
  """Creates an SSL.Connection and wraps it in a WrappedConnection.

  Args:
    context: the SSL.Context to use.
    connection: the transport handed to SSL.Connection (typically a socket).
  Returns:
    A WrappedConnection around the new SSL.Connection.
  """
  ssl_connection = SSL.Connection(context, connection)
  return WrappedConnection(ssl_connection)
83
84
def load_privatekey(key, filetype=FILETYPE_PEM):
  """Parses a serialized private key into a crypto.PKey.

  Args:
    key: the encoded private key string.
    filetype: encoding of `key`; PEM by default.
  Returns:
    A crypto.PKey object.
  """
  pkey = crypto.load_privatekey(filetype, key)
  return pkey
88
89
def load_cert(cert_str, filetype=FILETYPE_PEM):
  """Parses a serialized certificate into a crypto.X509.

  Args:
    cert_str: the encoded certificate string.
    filetype: encoding of `cert_str`; PEM by default.
  Returns:
    A crypto.X509 object.
  """
  x509 = crypto.load_certificate(filetype, cert_str)
  return x509
93
94
def _dump_privatekey(key, filetype=FILETYPE_PEM):
  """Serializes a crypto.PKey private key.

  Args:
    key: the crypto.PKey to serialize.
    filetype: output encoding; PEM by default.
  Returns:
    The encoded private key string.
  """
  encoded = crypto.dump_privatekey(filetype, key)
  return encoded
98
99
def _dump_cert(cert, filetype=FILETYPE_PEM):
  """Serializes a crypto.X509 certificate.

  Args:
    cert: the crypto.X509 to serialize.
    filetype: output encoding; PEM by default.
  Returns:
    The encoded certificate string.
  """
  encoded = crypto.dump_certificate(filetype, cert)
  return encoded
103
104
def generate_dummy_ca_cert(subject='_WebPageReplayCert', key_bits=2048):
  """Generates a self-signed dummy certificate authority.

  Args:
    subject: a string used as both the CN and O of the root cert's subject
      (and, since the cert is self-signed, of its issuer)
    key_bits: size in bits of the generated RSA key. Defaults to 2048:
      1024-bit keys are rejected by clients running at OpenSSL security
      level 2 or higher, and generate_cert both signs with this key and
      reuses its public key for every server certificate it issues.
  Returns:
    A tuple (ca_cert_str, key_str): the PEM-encoded root certificate and
    the PEM-encoded private key.
  Raises:
    ImportError: if pyOpenSSL could not be imported.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  key = crypto.PKey()
  key.generate_key(crypto.TYPE_RSA, key_bits)

  ca_cert = crypto.X509()
  ca_cert.set_serial_number(int(time.time()*10000))
  ca_cert.set_version(2)  # Zero-based, i.e. an X.509 v3 certificate.
  ca_cert.get_subject().CN = subject
  ca_cert.get_subject().O = subject
  # Valid from two years in the past until two years in the future.
  ca_cert.gmtime_adj_notBefore(-60 * 60 * 24 * 365 * 2)
  ca_cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 2)
  ca_cert.set_issuer(ca_cert.get_subject())
  ca_cert.set_pubkey(key)
  ca_cert.add_extensions([
      crypto.X509Extension('basicConstraints', True, 'CA:TRUE'),
      crypto.X509Extension('nsCertType', True, 'sslCA'),
      crypto.X509Extension('extendedKeyUsage', True,
                           ('serverAuth,clientAuth,emailProtection,'
                            'timeStamping,msCodeInd,msCodeCom,msCTLSign,'
                            'msSGC,msEFS,nsSGC')),
      crypto.X509Extension('keyUsage', False, 'keyCertSign, cRLSign'),
      crypto.X509Extension('subjectKeyIdentifier', False, 'hash',
                           subject=ca_cert),
      ])
  ca_cert.sign(key, 'sha256')
  key_str = _dump_privatekey(key)
  ca_cert_str = _dump_cert(ca_cert)
  return ca_cert_str, key_str
144
145
def get_host_cert(host, port=443):
  """Contacts the host and returns its certificate.

  Performs a TLS handshake with host:port, accepting any certificate, and
  records every certificate OpenSSL passes to the verify callback.

  Args:
    host: host name or address to connect to.
    port: TCP port to connect to.
  Returns:
    The PEM-encoded certificate seen last by the verify callback (OpenSSL
    checks the chain from the top down to the peer's own certificate), or
    '' if no certificate was received.
  Raises:
    ImportError: if pyOpenSSL could not be imported.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  host_certs = []
  def verify_cb(conn, cert, errnum, depth, ok):  # pylint: disable=unused-argument
    host_certs.append(cert)
    # Return True to accept the certificate whatever the verification result;
    # we only want to capture it.
    return True

  context = SSL.Context(SSL.SSLv23_METHOD)
  context.set_verify(SSL.VERIFY_PEER, verify_cb)  # Demand a certificate
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  connection = SSL.Connection(context, s)
  # NOTE(review): no SNI name is set on the connection, so servers hosting
  # several names may return their default certificate rather than host's.
  try:
    connection.connect((host, port))
    connection.send('')
  except SSL.SysCallError:
    pass
  except socket.gaierror:
    logging.debug('Host name is not valid')
  finally:
    # shutdown() can raise SSL.Error when no handshake completed (e.g. after
    # a failed connect). Swallow it so it neither masks the handling above
    # nor skips close() and leaks the socket.
    try:
      connection.shutdown()
    except SSL.Error:
      pass
    connection.close()
  if not host_certs:
    logging.warning('Unable to get host certificate from %s:%s', host, port)
    return ''
  return _dump_cert(host_certs[-1])
172
173
def write_dummy_ca_cert(ca_cert_str, key_str, cert_path):
  """Writes four certificate files.

  For example, if cert_path is "mycert.pem":
      mycert.pem - CA plus private key
      mycert-cert.pem - CA in PEM format
      mycert-cert.cer - CA for Android
      mycert-cert.p12 - CA in PKCS12 format for Windows devices
  Args:
    ca_cert_str: PEM certificate string
    key_str: PEM private key string
    cert_path: path string such as "mycert.pem"; missing parent directories
      are created.
  """
  dirname = os.path.dirname(cert_path)
  if dirname and not os.path.exists(dirname):
    os.makedirs(dirname)

  root_path = os.path.splitext(cert_path)[0]
  ca_cert_path = root_path + '-cert.pem'
  android_cer_path = root_path + '-cert.cer'
  windows_p12_path = root_path + '-cert.p12'

  # Dump the CA plus private key
  with open(cert_path, 'w') as f:
    f.write(key_str)
    f.write(ca_cert_str)

  # Dump the certificate in PEM format, plus a .cer file with the same
  # contents for Android.
  for path in (ca_cert_path, android_cer_path):
    with open(path, 'w') as f:
      f.write(ca_cert_str)

  # Build the PKCS12 bundle before opening the output file so a failure here
  # does not leave an empty .p12 behind.
  p12 = crypto.PKCS12()
  p12.set_certificate(load_cert(ca_cert_str))
  p12.set_privatekey(load_privatekey(key_str))
  # PKCS12 output is binary; write it in binary mode so newline translation
  # (e.g. on Windows) cannot corrupt it.
  with open(windows_p12_path, 'wb') as f:
    f.write(p12.export())
217
218
def generate_cert(root_ca_cert_str, server_cert_str, server_host):
  """Generates a server certificate signed by the root CA.

  The new certificate's subject is a single CN: the common name of
  server_cert_str when it has one, otherwise server_host. When
  server_cert_str is given, that certificate is the starting point (so its
  other fields, such as extensions, are kept) and is re-issued under the
  root CA. The result carries the root CA's own public key.

  Args:
    root_ca_cert_str: PEM formatted string holding both the root cert and
      its private key
    server_cert_str: PEM formatted string representing the real server
      cert, or a false value if none is available
    server_host: host name used as the common name when there is no
      server_cert_str or it has no common name
  Returns:
    a PEM formatted certificate string
  Raises:
    ImportError: if pyOpenSSL could not be imported.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  common_name = server_host
  if server_cert_str:
    cert = load_cert(server_cert_str)
    # commonName is None when the certificate has no CN entry; keep
    # server_host in that case instead of assigning None to the CN.
    common_name = cert.get_subject().commonName or server_host
  else:
    cert = crypto.X509()

  ca_cert = load_cert(root_ca_cert_str)
  key = load_privatekey(root_ca_cert_str)

  # The request carries only the CN subject and the CA's public key, both of
  # which are copied into the certificate below.
  req = crypto.X509Req()
  req.get_subject().CN = common_name
  req.set_pubkey(ca_cert.get_pubkey())
  req.sign(key, 'sha256')

  # Valid from one hour ago for 30 days.
  cert.gmtime_adj_notBefore(-60 * 60)
  cert.gmtime_adj_notAfter(60 * 60 * 24 * 30)
  cert.set_issuer(ca_cert.get_subject())
  cert.set_subject(req.get_subject())
  cert.set_serial_number(int(time.time()*10000))
  cert.set_pubkey(req.get_pubkey())
  cert.sign(key, 'sha256')

  return _dump_cert(cert)
257