1# Copyright 2014 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Routines to generate root and server certificates.
16
17Certificate Naming Conventions:
18  ca_cert:  crypto.X509 for the certificate authority (w/ both the pub &
19                priv keys)
20  cert:  a crypto.X509 certificate (w/ just the pub key)
21  cert_str:  a certificate string (w/ just the pub cert)
22  key:  a private crypto.PKey  (from ca or pem)
  ca_cert_str:  a certificate authority string (w/ both the pub & priv certs)
24"""
25
26import logging
27import os
28import platform
29import socket
30import subprocess
31import time
32
# Holds the ImportError raised when pyOpenSSL is unavailable, else None.
# Functions in this module that check it re-raise it when called.
openssl_import_error = None

# Module-level aliases for pyOpenSSL names. Each one stays None when the
# import below fails, so the module itself can still be imported.
Error = None
SSL_METHOD = None
SysCallError = None
VERIFY_PEER = None
ZeroReturnError = None
FILETYPE_PEM = None

try:
  from OpenSSL import crypto, SSL

  Error = SSL.Error
  SSL_METHOD = SSL.SSLv23_METHOD
  SysCallError = SSL.SysCallError
  VERIFY_PEER = SSL.VERIFY_PEER
  ZeroReturnError = SSL.ZeroReturnError
  FILETYPE_PEM = crypto.FILETYPE_PEM
except ImportError as e:  # 'as' form parses on Python 2.6+ and Python 3.
  # Defer the failure until a function that needs pyOpenSSL is called.
  openssl_import_error = e
53
54
def get_ssl_context(method=SSL_METHOD):
  """Creates an SSL.Context for the requested protocol method.

  Args:
    method: one of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or
        TLSv1_METHOD; defaults to the module-level SSL_METHOD
  Returns:
    A new SSL.Context built with `method`.
  Raises:
    The error stored in openssl_import_error, when it is set.
  """
  if not openssl_import_error:
    return SSL.Context(method)
  raise openssl_import_error  # pylint: disable=raising-bad-type
60
61
class WrappedConnection(object):
  """Proxy for a connection object whose recv() reports EOF as ''.

  Any attribute not defined on this class is looked up on the wrapped
  object, so the proxy can be used wherever the wrapped object is expected.
  """

  def __init__(self, obj):
    """Stores `obj`, the connection that all calls are forwarded to."""
    self._wrapped_obj = obj

  def __getattr__(self, attr):
    """Forwards attribute lookups that failed on the proxy itself.

    Raises:
      AttributeError: if `attr` is missing on the wrapped object, or if the
          proxy has no wrapped object yet.
    """
    # __getattr__ only runs after normal lookup has failed, so reaching here
    # with '_wrapped_obj' means __init__ never ran (for instance on an
    # instance made with __new__). Reading self._wrapped_obj below would then
    # re-enter this method without end, so fail with a plain AttributeError.
    if attr == '_wrapped_obj':
      raise AttributeError(attr)
    return getattr(self._wrapped_obj, attr)

  def recv(self, buflen=1024, flags=0):
    """Receives data from the wrapped connection.

    Args:
      buflen: maximum amount of data to read, passed through unchanged
      flags: receive flags, passed through unchanged
    Returns:
      Whatever the wrapped recv() returns, or '' when the peer closed the
      connection (SSL.ZeroReturnError, or SSL.SysCallError whose second
      argument is 'Unexpected EOF').
    Raises:
      SSL.SysCallError: for any other system call failure.
    """
    try:
      return self._wrapped_obj.recv(buflen, flags)
    except SSL.SysCallError as e:
      # Compare a slice so that an error carrying fewer than two arguments
      # is re-raised as-is instead of turning into an IndexError.
      if e.args[1:2] == ('Unexpected EOF',):
        return ''
      raise
    except SSL.ZeroReturnError:
      return ''
81
82
def get_ssl_connection(context, connection):
  """Builds an SSL.Connection over `connection` and wraps it.

  Args:
    context: the SSL context handed to SSL.Connection
    connection: the underlying connection handed to SSL.Connection
  Returns:
    A WrappedConnection around the new SSL.Connection.
  """
  ssl_connection = SSL.Connection(context, connection)
  return WrappedConnection(ssl_connection)
85
86
def load_privatekey(key, filetype=FILETYPE_PEM):
  """Parses a private key string into a private key object.

  Args:
    key: the serialized private key
    filetype: encoding of `key`; defaults to FILETYPE_PEM
  Returns:
    The object returned by crypto.load_privatekey.
  """
  private_key = crypto.load_privatekey(filetype, key)
  return private_key
90
91
def load_cert(cert_str, filetype=FILETYPE_PEM):
  """Parses a certificate string into a certificate object.

  Args:
    cert_str: the serialized certificate
    filetype: encoding of `cert_str`; defaults to FILETYPE_PEM
  Returns:
    The object returned by crypto.load_certificate.
  """
  certificate = crypto.load_certificate(filetype, cert_str)
  return certificate
95
96
def _dump_privatekey(key, filetype=FILETYPE_PEM):
  """Serializes a private key object to a string.

  Args:
    key: the private key object to serialize
    filetype: output encoding; defaults to FILETYPE_PEM
  Returns:
    The string returned by crypto.dump_privatekey.
  """
  key_str = crypto.dump_privatekey(filetype, key)
  return key_str
100
101
def _dump_cert(cert, filetype=FILETYPE_PEM):
  """Serializes a certificate object to a string.

  Args:
    cert: the certificate object to serialize
    filetype: output encoding; defaults to FILETYPE_PEM
  Returns:
    The string returned by crypto.dump_certificate.
  """
  cert_str = crypto.dump_certificate(filetype, cert)
  return cert_str
105
106
def generate_dummy_ca_cert(subject='_WebPageReplayCert'):
  """Creates a self-signed dummy certificate authority.

  The certificate uses a freshly generated 1024-bit RSA key, has `subject` as
  both its common name and organization, and is valid from two years in the
  past to two years in the future.

  Args:
    subject: a string representing the desired root cert issuer
  Returns:
    A (ca_cert_str, key_str) tuple: the serialized certificate followed by
    the serialized private key.
  Raises:
    The error stored in openssl_import_error, when it is set.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  two_years_in_seconds = 60 * 60 * 24 * 365 * 2
  extended_key_usages = ('serverAuth,clientAuth,emailProtection,'
                         'timeStamping,msCodeInd,msCodeCom,msCTLSign,'
                         'msSGC,msEFS,nsSGC')

  rsa_key = crypto.PKey()
  rsa_key.generate_key(crypto.TYPE_RSA, 1024)

  root = crypto.X509()
  root.set_serial_number(int(time.time()*10000))
  root.set_version(2)
  root.get_subject().CN = subject
  root.get_subject().O = subject
  root.gmtime_adj_notBefore(-two_years_in_seconds)
  root.gmtime_adj_notAfter(two_years_in_seconds)
  # Self-signed: the issuer is the certificate's own subject.
  root.set_issuer(root.get_subject())
  root.set_pubkey(rsa_key)

  # Built after set_pubkey, as in the original statement order, because the
  # subjectKeyIdentifier extension is created from the certificate itself.
  extensions = [
      crypto.X509Extension('basicConstraints', True, 'CA:TRUE'),
      crypto.X509Extension('subjectAltName', False, 'DNS:' + subject),
      crypto.X509Extension('nsCertType', True, 'sslCA'),
      crypto.X509Extension('extendedKeyUsage', True, extended_key_usages),
      crypto.X509Extension('keyUsage', False, 'keyCertSign, cRLSign'),
      crypto.X509Extension('subjectKeyIdentifier', False, 'hash',
                           subject=root),
  ]
  root.add_extensions(extensions)
  root.sign(rsa_key, 'sha256')

  private_key_pem = _dump_privatekey(rsa_key)
  certificate_pem = _dump_cert(root)
  return certificate_pem, private_key_pem
147
148
def get_host_cert(host, port=443):
  """Contacts the host and returns its certificate.

  Args:
    host: host name to connect to
    port: TCP port to connect to
  Returns:
    The serialized form of the last certificate passed to the verify
    callback, or '' if no certificate was received.
  Raises:
    The error stored in openssl_import_error, when it is set.
  """
  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  host_certs = []
  def verify_cb(conn, cert, errnum, depth, ok):
    # pylint: disable=unused-argument
    host_certs.append(cert)
    # Return True to indicate that the certificate was ok, so that every
    # certificate is collected regardless of whether it verifies.
    return True

  context = SSL.Context(SSL.SSLv23_METHOD)
  context.set_verify(SSL.VERIFY_PEER, verify_cb)  # Demand a certificate
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  connection = SSL.Connection(context, s)
  try:
    connection.connect((host, port))
    # Sending forces the handshake, which is what invokes verify_cb.
    connection.send('')
  except SSL.SysCallError:
    pass
  except socket.gaierror:
    logging.debug('Host name is not valid')
  finally:
    try:
      connection.shutdown()
    except (SSL.Error, socket.error):
      # Shutdown can fail when the connection was never established. This is
      # cleanup only, so ignore it instead of masking the outcome above.
      pass
    finally:
      # Always release the socket, even if shutdown raised.
      connection.close()
  if not host_certs:
    logging.warning('Unable to get host certificate from %s:%s', host, port)
    return ''
  return _dump_cert(host_certs[-1])
175
176
def write_dummy_ca_cert(ca_cert_str, key_str, cert_path):
  """Writes four certificate files.

  For example, if cert_path is "mycert.pem":
      mycert.pem - CA plus private key
      mycert-cert.pem - CA in PEM format
      mycert-cert.cer - CA for Android
      mycert-cert.p12 - CA in PKCS12 format for Windows devices
  The directory of cert_path is created if it does not exist.

  Args:
    ca_cert_str: certificate string
    key_str: private key string
    cert_path: path string such as "mycert.pem"
  """
  dirname = os.path.dirname(cert_path)
  if dirname and not os.path.exists(dirname):
    os.makedirs(dirname)

  root_path = os.path.splitext(cert_path)[0]
  ca_cert_path = root_path + '-cert.pem'
  android_cer_path = root_path + '-cert.cer'
  windows_p12_path = root_path + '-cert.p12'

  # Dump the CA plus private key
  with open(cert_path, 'w') as f:
    f.write(key_str)
    f.write(ca_cert_str)

  # Dump the certificate in PEM format, and again as a .cer file with the
  # same contents for Android.
  for pem_path in (ca_cert_path, android_cer_path):
    with open(pem_path, 'w') as f:
      f.write(ca_cert_str)

  # Build the PKCS12 bundle for Windows devices.
  p12 = crypto.PKCS12()
  p12.set_certificate(load_cert(ca_cert_str))
  p12.set_privatekey(load_privatekey(key_str))
  # PKCS12 is a binary format: write it in binary mode so that no newline
  # translation can corrupt it.
  with open(windows_p12_path, 'wb') as f:
    f.write(p12.export())
220
221
def generate_cert(root_ca_cert_str, server_cert_str, server_host):
  """Generates a cert_str for a server, signed by the root_ca_cert_str.

  When server_cert_str is given, its subject common name and its whitelisted
  extensions (subjectAltName) are copied into the generated certificate.

  Args:
    root_ca_cert_str: PEM formatted string holding the root cert and its
        private key
    server_cert_str: PEM formatted string representing cert; may be empty
    server_host: host name to use if there is no server_cert_str, or if
        that certificate has no common name
  Returns:
    a PEM formatted certificate string
  Raises:
    The error stored in openssl_import_error, when it is set.
  """
  EXTENSION_WHITELIST = set(['subjectAltName'])

  if openssl_import_error:
    raise openssl_import_error  # pylint: disable=raising-bad-type

  common_name = server_host
  reused_extensions = []
  if server_cert_str:
    original_cert = load_cert(server_cert_str)
    # A certificate without a CN reports commonName as None, which cannot be
    # assigned to the new subject below, so fall back to the host name.
    common_name = original_cert.get_subject().commonName or server_host
    for i in range(original_cert.get_extension_count()):
      original_cert_extension = original_cert.get_extension(i)
      if original_cert_extension.get_short_name() in EXTENSION_WHITELIST:
        reused_extensions.append(original_cert_extension)

  # The root string carries both the certificate and the private key.
  ca_cert = load_cert(root_ca_cert_str)
  ca_key = load_privatekey(root_ca_cert_str)

  cert = crypto.X509()
  # Extensions are only valid in X.509 v3 certificates (version value 2).
  cert.set_version(2)
  cert.get_subject().CN = common_name
  # Valid from one hour ago until 30 days from now.
  cert.gmtime_adj_notBefore(-60 * 60)
  cert.gmtime_adj_notAfter(60 * 60 * 24 * 30)
  cert.set_issuer(ca_cert.get_subject())
  cert.set_serial_number(int(time.time()*10000))
  # The generated certificate reuses the CA key as its own public key.
  cert.set_pubkey(ca_key)
  cert.add_extensions(reused_extensions)
  cert.sign(ca_key, 'sha256')

  return _dump_cert(cert)
262
263
def install_cert_in_nssdb(home_directory_path, certificate_path):
  """Adds a certificate to the .pki/nssdb database under a home directory.

  The database is created first if it does not exist yet. The certificate is
  added with trust flags 'PC,,' under a nickname equal to its path.

  Args:
    home_directory_path: Path of the home directory where to install
    certificate_path: Path of a CA in PEM format
  Raises:
    AssertionError: if home_directory_path is not a directory, or the
        platform is not Linux.
    Exception: if home_directory_path is the current $HOME.
    subprocess.CalledProcessError: if a certutil invocation fails.
  """
  assert os.path.isdir(home_directory_path)
  assert platform.system() == 'Linux', \
      'SSL certification authority has only been tested for linux.'
  # Refuse to touch the real user's certificate database.
  target_home = os.path.abspath(home_directory_path)
  if target_home == os.path.abspath(os.environ['HOME']):
    raise Exception('Modifying $HOME/.pki/nssdb compromises your machine.')

  nssdb_dir = os.path.join(home_directory_path, '.pki', 'nssdb')

  def certutil(args):
    # Runs certutil against the sql database in nssdb_dir, logging the command.
    command = ['certutil', '--empty-password', '-d', 'sql:' + nssdb_dir]
    command += args
    logging.info(subprocess.list2cmdline(command))
    subprocess.check_call(command)

  database_exists = os.path.isdir(nssdb_dir)
  if not database_exists:
    os.makedirs(nssdb_dir)
    certutil(['-N'])

  certutil(['-A', '-t', 'PC,,', '-n', certificate_path, '-i', certificate_path])
290