1#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
14
15"""Functions that load and write PEM-encoded files."""
16
17import base64
18import typing
19
# Marker and content arguments may be given either as bytes or as text (str).
# Text is converted with the 'ascii' codec before use, so it must contain only
# ASCII characters.
FlexiText = typing.Union[str, bytes]
22
23
def _markers(pem_marker: FlexiText) -> typing.Tuple[bytes, bytes]:
    """Build the BEGIN and END marker lines for a PEM section.

    :param pem_marker: the label, e.g. 'RSA PRIVATE KEY'. Anything that is not
        already ``bytes`` is encoded with the 'ascii' codec.
    :return: a ``(start, end)`` tuple such as
        ``(b'-----BEGIN X-----', b'-----END X-----')``.
    """
    if isinstance(pem_marker, bytes):
        label = pem_marker
    else:
        label = pem_marker.encode('ascii')

    begin_line = b''.join((b'-----BEGIN ', label, b'-----'))
    end_line = b''.join((b'-----END ', label, b'-----'))
    return begin_line, end_line
34
35
36def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iterator[bytes]:
37    """Generator over PEM lines between pem_start and pem_end."""
38
39    in_pem_part = False
40    seen_pem_start = False
41
42    for line in contents.splitlines():
43        line = line.strip()
44
45        # Skip empty lines
46        if not line:
47            continue
48
49        # Handle start marker
50        if line == pem_start:
51            if in_pem_part:
52                raise ValueError('Seen start marker "%r" twice' % pem_start)
53
54            in_pem_part = True
55            seen_pem_start = True
56            continue
57
58        # Skip stuff before first marker
59        if not in_pem_part:
60            continue
61
62        # Handle end marker
63        if in_pem_part and line == pem_end:
64            in_pem_part = False
65            break
66
67        # Load fields
68        if b':' in line:
69            continue
70
71        yield line
72
73    # Do some sanity checks
74    if not seen_pem_start:
75        raise ValueError('No PEM start marker "%r" found' % pem_start)
76
77    if in_pem_part:
78        raise ValueError('No PEM end marker "%r" found' % pem_end)
79
80
def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes:
    """Loads a PEM file.

    :param contents: the contents of the file to interpret, as bytes or as a
        str containing only ASCII characters.
    :param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
        when your file has '-----BEGIN RSA PRIVATE KEY-----' and
        '-----END RSA PRIVATE KEY-----' markers.

    :return: the base64-decoded content between the start and end markers.

    :raises ValueError: when the content is invalid. Examples include:

        * the start marker cannot be found;
        * the end marker is missing;
        * a str argument contains non-ASCII characters (UnicodeEncodeError);
        * the base64 payload is incorrectly padded (binascii.Error).

        UnicodeEncodeError and binascii.Error are both ValueError subclasses.
    """

    # We want bytes, not text. If it's text, it can be converted to ASCII bytes.
    if not isinstance(contents, bytes):
        contents = contents.encode('ascii')

    pem_start, pem_end = _markers(pem_marker)

    # _pem_lines already drops blank lines and header fields (lines containing
    # ':'), so joining its output yields just the base64 payload. Joining the
    # generator directly avoids building a redundant intermediate list.
    pem = b''.join(_pem_lines(contents, pem_start, pem_end))
    return base64.standard_b64decode(pem)
106
107
def save_pem(contents: bytes, pem_marker: FlexiText) -> bytes:
    """Saves a PEM file.

    :param contents: the contents to encode in PEM format
    :param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
        when your file has '-----BEGIN RSA PRIVATE KEY-----' and
        '-----END RSA PRIVATE KEY-----' markers.

    :return: the base64-encoded content between the start and end markers, as
        bytes. The base64 text is wrapped at 64 characters per line, and every
        line, including the last, is terminated by a newline byte.
    """

    pem_start, pem_end = _markers(pem_marker)

    # standard_b64encode returns one unbroken line (unlike base64.encodebytes),
    # so there are no embedded newlines to strip before re-wrapping.
    b64 = base64.standard_b64encode(contents)

    # Wrap the payload into 64-character lines; the last may be shorter.
    body = [b64[offset:offset + 64] for offset in range(0, len(b64), 64)]

    # The trailing empty element makes the joined output end with a newline.
    return b'\n'.join([pem_start] + body + [pem_end, b''])
133