1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Pure Python crypto-related routines for oauth2client.
16
17Uses the ``rsa``, ``pyasn1`` and ``pyasn1_modules`` packages
18to parse PEM files storing PKCS#1 or PKCS#8 keys as well as
19certificates.
20"""
21
22from pyasn1.codec.der import decoder
23from pyasn1_modules import pem
24from pyasn1_modules.rfc2459 import Certificate
25from pyasn1_modules.rfc5208 import PrivateKeyInfo
26import rsa
27import six
28
29from oauth2client import _helpers
30
31
# Error text explaining that PKCS#12 (.p12) keys are not supported by the
# ``rsa`` library and showing how to convert such a key to PEM.
# This is deliberately NOT a raw string: in a raw string the backslash after
# the opening quotes would be kept literally (together with the newline) and
# the message would start with a stray "\". The shell line continuations in
# the example command are therefore written as escaped backslashes (``\\``).
_PKCS12_ERROR = """\
PKCS12 format is not supported by the RSA library.
Either install PyOpenSSL, or please convert .p12 format
to .pem format:
    $ cat key.p12 | \\
    >   openssl pkcs12 -nodes -nocerts -passin pass:notasecret | \\
    >   openssl rsa > key.pem
"""
40
# Bit weights 2**7 .. 2**0, most significant first.
_POW2 = tuple(1 << shift for shift in range(7, -1, -1))

# (BEGIN, END) PEM armor lines for PKCS#1 "RSA PRIVATE KEY" blocks.
_PKCS1_MARKER = (
    '-----BEGIN RSA PRIVATE KEY-----',
    '-----END RSA PRIVATE KEY-----',
)

# (BEGIN, END) PEM armor lines for PKCS#8 "PRIVATE KEY" blocks.
_PKCS8_MARKER = (
    '-----BEGIN PRIVATE KEY-----',
    '-----END PRIVATE KEY-----',
)
# pyasn1 schema instance passed as ``asn1Spec`` when decoding the DER payload
# of a PKCS#8 PEM block (see RsaSigner.from_string). Built once at import
# time and reused for every decode.
_PKCS8_SPEC = PrivateKeyInfo()
47
48
def _bit_list_to_bytes(bit_list):
    """Pack a sequence of 1's and 0's into bytes.

    The bits are taken eight at a time and each group becomes one byte, with
    the first bit of a group weighted 128 and the last weighted 1. If the
    length is not a multiple of eight, the final short group occupies the
    high-order bits of its byte and the missing low-order bits count as 0.

    Args:
        bit_list: sequence supporting ``len()`` and slicing whose items are
                  0 or 1 (for example the ``subjectPublicKey`` BIT STRING of
                  a decoded certificate).

    Returns:
        bytes, one byte per (possibly partial) group of eight bits.
    """
    groups = (bit_list[offset:offset + 8]
              for offset in range(0, len(bit_list), 8))
    packed = bytearray(
        sum(weight * bit for weight, bit in zip(_POW2, group))
        for group in groups)
    return bytes(packed)
63
64
class RsaVerifier(object):
    """Verifies the signature on a message.

    Args:
        pubkey: rsa.key.PublicKey (or equiv), The public key to verify with.
    """

    def __init__(self, pubkey):
        self._pubkey = pubkey

    def verify(self, message, signature):
        """Verifies a message against a signature.

        Args:
            message: string or bytes, The message to verify. If string, will be
                     encoded to bytes as utf-8.
            signature: string or bytes, The signature on the message. If
                       string, will be encoded to bytes as utf-8.

        Returns:
            True if message was signed by the private key associated with the
            public key that this object was constructed with, otherwise False.
        """
        message = _helpers._to_bytes(message, encoding='utf-8')
        # Honor the documented contract for text signatures. Byte-like values
        # are passed through untouched, exactly as before.
        if isinstance(signature, six.text_type):
            signature = signature.encode('utf-8')
        try:
            rsa.pkcs1.verify(message, signature, self._pubkey)
        except (ValueError, rsa.pkcs1.VerificationError):
            return False
        # rsa.pkcs1.verify signals failure by raising. Newer releases of the
        # ``rsa`` package return the hash-method name on success rather than
        # True, so return an explicit boolean to keep the documented contract.
        return True

    @classmethod
    def from_string(cls, key_pem, is_x509_cert):
        """Construct an RsaVerifier instance from a string.

        Args:
            key_pem: string, public key in PEM format (converted to bytes
                     with ``_helpers._to_bytes`` before parsing).
            is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it
                          is expected to be an RSA key in PEM format.

        Returns:
            RsaVerifier instance.

        Raises:
            ValueError: if the key_pem can't be parsed. If the PEM start marker
                        is missing, the error begins with
                        'No PEM start marker'. If ``is_x509_cert`` is True it
                        fails to find "-----BEGIN CERTIFICATE-----", otherwise
                        it fails to find "-----BEGIN RSA PUBLIC KEY-----".
                        For certificates, a ValueError('Unused bytes', ...) is
                        raised if data remains after the DER certificate.
        """
        key_pem = _helpers._to_bytes(key_pem)
        if is_x509_cert:
            der = rsa.pem.load_pem(key_pem, 'CERTIFICATE')
            asn1_cert, remaining = decoder.decode(der, asn1Spec=Certificate())
            if remaining != b'':
                raise ValueError('Unused bytes', remaining)

            # The certificate embeds the PKCS#1 public key as a BIT STRING;
            # pack it back into DER bytes before handing it to ``rsa``.
            cert_info = asn1_cert['tbsCertificate']['subjectPublicKeyInfo']
            key_bytes = _bit_list_to_bytes(cert_info['subjectPublicKey'])
            pubkey = rsa.PublicKey.load_pkcs1(key_bytes, 'DER')
        else:
            pubkey = rsa.PublicKey.load_pkcs1(key_pem, 'PEM')
        return cls(pubkey)
126
127
class RsaSigner(object):
    """Signs messages with an RSA private key.

    Args:
        pkey: rsa.key.PrivateKey (or equiv), The private key to sign with.
    """

    def __init__(self, pkey):
        self._key = pkey

    def sign(self, message):
        """Sign ``message`` with the wrapped key using SHA-256.

        Args:
            message: bytes or string, Message to be signed. A text message is
                     encoded to bytes as utf-8 first.

        Returns:
            The signature returned by ``rsa.pkcs1.sign`` for the given key.
        """
        return rsa.pkcs1.sign(
            _helpers._to_bytes(message, encoding='utf-8'), self._key,
            'SHA-256')

    @classmethod
    def from_string(cls, key, password='notasecret'):
        """Build an RsaSigner from a PEM-encoded private key.

        Args:
            key: string or bytes, private key in PEM format, either a PKCS#1
                 "RSA PRIVATE KEY" block or a PKCS#8 "PRIVATE KEY" block.
            password: string, password for private key file. Unused for PEM
                      files.

        Returns:
            RsaSigner instance.

        Raises:
            ValueError: if neither PEM marker is found ('No key could be
                        detected.'), or if the PKCS#8 structure has trailing
                        data ('Unused bytes').
        """
        key_text = _helpers._from_bytes(key)  # pem expects str in Py3
        marker_index, der_bytes = pem.readPemBlocksFromFile(
            six.StringIO(key_text), _PKCS1_MARKER, _PKCS8_MARKER)

        if marker_index not in (0, 1):
            raise ValueError('No key could be detected.')

        if marker_index == 1:
            # PKCS#8 wraps the PKCS#1 key inside a PrivateKeyInfo structure;
            # unwrap it so both formats end up as PKCS#1 DER bytes.
            key_info, leftover = decoder.decode(
                der_bytes, asn1Spec=_PKCS8_SPEC)
            if leftover != b'':
                raise ValueError('Unused bytes', leftover)
            der_bytes = key_info.getComponentByName('privateKey').asOctets()

        private_key = rsa.key.PrivateKey.load_pkcs1(der_bytes, format='DER')
        return cls(private_key)
185