1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15r"""Decodes and detokenizes strings from binary or Base64 input.
16
The main class provided by this module is the Detokenizer class. To use it,
18construct it with the path to an ELF or CSV database, a tokens.Database,
19or a file object for an ELF file or CSV. Then, call the detokenize method with
20encoded messages, one at a time. The detokenize method returns a
21DetokenizedString object with the result.
22
23For example,
24
25  from pw_tokenizer import detokenize
26
27  detok = detokenize.Detokenizer('path/to/my/image.elf')
28  print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!'))
29
30This module also provides a command line interface for decoding and detokenizing
31messages from a file or stdin.
32"""
33
34import argparse
35import base64
36import binascii
37from datetime import datetime
38import io
39import logging
40import os
41from pathlib import Path
42import re
43import string
44import struct
45import sys
46import time
47from typing import (BinaryIO, Callable, Dict, List, Iterable, Iterator, Match,
48                    NamedTuple, Optional, Pattern, Tuple, Union)
49
50try:
51    from pw_tokenizer import database, decode, encode, tokens
52except ImportError:
53    # Append this path to the module search path to allow running this module
54    # without installing the pw_tokenizer package.
55    sys.path.append(os.path.dirname(os.path.dirname(
56        os.path.abspath(__file__))))
57    from pw_tokenizer import database, decode, encode, tokens
58
# Layout of the token at the start of an encoded message: a little-endian
# unsigned 32-bit integer. The bytes after it are the encoded arguments.
ENCODED_TOKEN = struct.Struct('<I')
# Logger used by this module, registered under the 'pw_tokenizer' name.
_LOG = logging.getLogger('pw_tokenizer')
61
62
class DetokenizedString:
    """Result of detokenizing one message, including all colliding matches."""
    def __init__(self,
                 token: Optional[int],
                 format_string_entries: Iterable[tuple],
                 encoded_message: bytes,
                 show_errors: bool = False):
        """Decodes the message arguments with every candidate format string.

        Args:
          token: the token read from the message, or None if there was none
          format_string_entries: (database entry, format string) pairs for
              the strings that share the token
          encoded_message: the complete encoded message, token included
          show_errors: passed to each format string when formatting; also
              makes str() return an error message if nothing matched
        """
        self.token = token
        self.encoded_message = encoded_message
        self._show_errors = show_errors

        # The encoded arguments are whatever follows the token.
        encoded_args = encoded_message[ENCODED_TOKEN.size:]

        scored: List[Tuple[Tuple, decode.FormattedString]] = []

        for entry, fmt in format_string_entries:
            attempt = fmt.format(encoded_args, show_errors)

            error_count = sum(not arg.ok() for arg in attempt.args)
            decoded_all_data = not attempt.remaining

            # Rank competing entries so the most likely matches come first.
            # Tuples compare element by element, so earlier elements have
            # higher priority. An attempt ranks higher if it
            #
            #   1. decoded all bytes for all arguments without errors,
            #   2. decoded all data,
            #   3. has the fewest decoding errors,
            #   4. decoded the most arguments successfully, or
            #   5. has the most recent removal date, if it was removed.
            #
            # This must match the collision resolution logic in detokenize.cc.
            rank: Tuple = (
                error_count == 0 and decoded_all_data,
                decoded_all_data,
                -error_count,
                len(attempt.args),
                entry.date_removed or datetime.max,
            )

            scored.append((rank, attempt))

        # Highest rank first; ties keep their original relative order.
        ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)

        # Split the successful decodes from the failures. The first element
        # of the rank is True only for a complete, error-free decode.
        self.successes: List[decode.FormattedString] = [
            attempt for rank, attempt in ranked if rank[0]
        ]
        self.failures: List[decode.FormattedString] = [
            attempt for rank, attempt in ranked if not rank[0]
        ]

    def ok(self) -> bool:
        """True if exactly one string decoded the arguments successfully."""
        return len(self.successes) == 1

    def matches(self) -> List[decode.FormattedString]:
        """Returns the strings that matched the token, best matches first."""
        return self.successes + self.failures

    def best_result(self) -> Optional[decode.FormattedString]:
        """Returns the most likely decoded string, or None without matches."""
        candidates = self.matches()
        return candidates[0] if candidates else None

    def error_message(self) -> str:
        """Describes why detokenization failed; empty if it succeeded."""
        if self.ok():
            return ''

        found = self.matches()

        if len(found) > 1:
            return '{} matches'.format(len(found))

        if found:
            return 'decoding failed for {!r}'.format(found[0].value)

        if self.token is None:
            return 'missing token'

        return 'unknown token {:08x}'.format(self.token)

    def __str__(self) -> str:
        """Returns the string for the most likely result."""
        best = self.best_result()
        if best:
            return best[0]

        if not self._show_errors:
            # Display the string as prefixed Base64 if it cannot be decoded.
            return encode.prefixed_base64(self.encoded_message)

        return '<[ERROR: {}|{!r}]>'.format(self.error_message(),
                                           self.encoded_message)

    def __repr__(self) -> str:
        description = (repr(str(self))
                       if self.ok() else 'ERROR: {}|{!r}'.format(
                           self.error_message(), self.encoded_message))
        return '{}({})'.format(type(self).__name__, description)
164
165
class _TokenizedFormatString(NamedTuple):
    """Pairs a token database entry with the format string built from it."""
    # The database entry for the tokenized string.
    entry: tokens.TokenizedStringEntry
    # The decode.FormatString used to format the entry's arguments;
    # Detokenizer.lookup constructs it from str(entry).
    format: decode.FormatString
169
170
class Detokenizer:
    """Detokenizes messages with a token database; caches format strings."""
    def __init__(self, *token_database_or_elf, show_errors: bool = False):
        """Loads the token database to use for detokenization.

        Args:
          *token_database_or_elf: a path or file object for an ELF or CSV
              database, a tokens.Database, or an elf_reader.Elf
          show_errors: if True, an error message is used in place of the %
              conversion specifier when an argument fails to decode
        """
        self.database = database.load_token_database(*token_database_or_elf)
        self.show_errors = show_errors

        # Maps each token that has been looked up to its format strings so
        # they are only constructed once.
        self._cache: Dict[int, List[_TokenizedFormatString]] = {}

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        """Returns the (TokenizedStringEntry, FormatString) pairs for a token."""
        cached = self._cache.get(token)

        if cached is None:
            # First request for this token: build and remember its entries.
            cached = [
                _TokenizedFormatString(entry, decode.FormatString(str(entry)))
                for entry in self.database.token_to_entries[token]
            ]
            self._cache[token] = cached

        return cached

    def detokenize(self, encoded_message: bytes) -> DetokenizedString:
        """Decodes and detokenizes a message as a DetokenizedString."""
        if len(encoded_message) < ENCODED_TOKEN.size:
            # The message is too short to contain a token.
            token = None
            entries: Iterable[tuple] = ()
        else:
            token = ENCODED_TOKEN.unpack_from(encoded_message)[0]
            entries = self.lookup(token)

        return DetokenizedString(token, entries, encoded_message,
                                 self.show_errors)
209
210
class AutoUpdatingDetokenizer:
    """Loads and updates a detokenizer from database paths."""
    class _DatabasePath:
        """Tracks the modified time of a path or file object."""
        def __init__(self, path):
            # File objects are tracked through their name attribute so the
            # modification time can be polled like any other path.
            self.path = path if isinstance(path, (str, Path)) else path.name
            self._modified_time: Optional[float] = self._last_modified_time()

        def updated(self) -> bool:
            """True if the path has been updated since the last call.

            A missing file is not reported as an update.
            """
            modified_time = self._last_modified_time()
            if modified_time is None or modified_time == self._modified_time:
                return False

            self._modified_time = modified_time
            return True

        def _last_modified_time(self) -> Optional[float]:
            """Returns the file's modification time, or None if it is missing."""
            try:
                return os.path.getmtime(self.path)
            except FileNotFoundError:
                return None

        def load(self) -> tokens.Database:
            """Loads the database; falls back to loading with no sources."""
            try:
                return database.load_token_database(self.path)
            except FileNotFoundError:
                return database.load_token_database()

    def __init__(self,
                 *paths_or_files,
                 min_poll_period_s: float = 1.0) -> None:
        """Loads the databases and prepares to poll them for changes.

        Args:
          *paths_or_files: paths (str or Path) or file objects for the token
              databases to load and watch
          min_poll_period_s: minimum time in seconds between checks for
              changes to the databases
        """
        self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
        self.min_poll_period_s = min_poll_period_s
        # Use the monotonic clock for the poll interval so that adjustments to
        # the system clock can neither delay nor force a check.
        self._last_checked_time: float = time.monotonic()
        self._detokenizer = Detokenizer(*(path.load() for path in self.paths))

    def detokenize(self, data: bytes) -> DetokenizedString:
        """Updates the token database if it has changed, then detokenizes."""
        now = time.monotonic()

        if now - self._last_checked_time >= self.min_poll_period_s:
            self._last_checked_time = now

            # Check every path rather than stopping at the first change, so
            # that each path records its new modification time. Otherwise
            # paths that changed together would cause a redundant reload on
            # the next poll.
            changes = [path.updated() for path in self.paths]

            if any(changes):
                _LOG.info('Changes detected; reloading token database')
                self._detokenizer = Detokenizer(*(path.load()
                                                  for path in self.paths))

        return self._detokenizer.detokenize(data)
259
260
# Type alias for either detokenizer class; both provide a detokenize method
# that takes bytes and returns a DetokenizedString.
_Detokenizer = Union[Detokenizer, AutoUpdatingDetokenizer]
262
263
class PrefixedMessageDecoder:
    """Parses messages that start with a prefix character from a byte stream."""
    def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
        """Parses prefixed messages.

        Args:
          prefix: one character that signifies the start of a message
          chars: characters allowed in a message

        Raises:
          ValueError: the prefix is not exactly one byte long or is one of
              the valid message characters
        """
        self._prefix = prefix.encode() if isinstance(prefix, str) else prefix

        if isinstance(chars, str):
            chars = chars.encode()

        # Store the valid message bytes as a set of one-byte binary strings.
        # Iterating over bytes yields ints, so convert each one back to bytes.
        self._message_bytes = frozenset(bytes((value, )) for value in chars)

        if len(self._prefix) != 1 or self._prefix in self._message_bytes:
            raise ValueError(
                'Invalid prefix {!r}: the prefix must be a single '
                'character that is not a valid message character.'.format(
                    prefix))

        # Working buffer; it only holds the message currently being read.
        self.data = bytearray()

    def _read_next(self, fd: BinaryIO) -> Tuple[bytes, int]:
        """Reads one character; returns it and its index in the buffer.

        The character is empty at the end of the file.
        """
        char = fd.read(1)
        index = len(self.data)
        self.data += char
        return char, index

    def read_messages(self,
                      binary_fd: BinaryIO) -> Iterator[Tuple[bool, bytes]]:
        """Parses prefixed messages; yields (is_message, contents) chunks.

        Message chunks include the prefix character. All other characters
        are yielded one at a time as non-message chunks.
        """
        message_start = None

        while True:
            # This reads the file character-by-character. Non-message characters
            # are yielded right away; message characters are grouped.
            char, index = self._read_next(binary_fd)

            # If in a message, keep reading until the message completes.
            if message_start is not None:
                if char in self._message_bytes:
                    continue

                # Yield an immutable copy, as promised by the return type.
                yield True, bytes(self.data[message_start:index])
                message_start = None

            # Handle a non-message character. An empty read is the end of file.
            if not char:
                self.data.clear()
                return

            if char == self._prefix:
                # Drop everything before the prefix so the buffer does not
                # grow without bound while reading a long-running stream.
                del self.data[:index]
                message_start = 0
            else:
                self.data.clear()
                yield False, char

    def transform(self, binary_fd: BinaryIO,
                  transform: Callable[[bytes], bytes]) -> Iterator[bytes]:
        """Yields the file with a transformation applied to the messages."""
        for is_message, chunk in self.read_messages(binary_fd):
            yield transform(chunk) if is_message else chunk
329
330
def _detokenize_prefixed_base64(
        detokenizer: _Detokenizer, prefix: bytes,
        recursion: int) -> Callable[[Match[bytes]], bytes]:
    """Creates a regex substitution callback that detokenizes Base64 matches.

    Args:
      detokenizer: the detokenizer with which to decode messages
      prefix: the prefix to use when decoding nested messages
      recursion: how many more levels of nested messages to decode

    Returns:
      a function that maps a match of a prefixed Base64 message to its
      detokenized bytes, or to the matched text if it cannot be detokenized
    """
    def decode_and_detokenize(match: Match[bytes]) -> bytes:
        """Detokenizes one match; returns it unchanged if that fails."""
        matched_text = match.group(0)

        try:
            # Skip the one-character prefix; the rest must be valid Base64.
            payload = base64.b64decode(matched_text[1:], validate=True)
            detokenized = detokenizer.detokenize(payload)

            if not detokenized.matches():
                return matched_text

            decoded = str(detokenized).encode()

            # The decoded text may itself contain tokenized messages.
            if recursion > 0 and matched_text != decoded:
                decoded = detokenize_base64(detokenizer, decoded, prefix,
                                            recursion - 1)

            return decoded
        except binascii.Error:
            return matched_text

    return decode_and_detokenize
356
357
# Default prefix that marks the start of a Base64 message, encoded as bytes.
BASE64_PREFIX = encode.BASE64_PREFIX.encode()
# Default number of levels of nested Base64 messages to decode recursively.
DEFAULT_RECURSION = 9
360
361
362def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
363    """Returns a regular expression for prefixed base64 tokenized strings."""
364    return re.compile(
365        # Base64 tokenized strings start with the prefix character ($)
366        re.escape(prefix) + (
367            # Tokenized strings contain 0 or more blocks of four Base64 chars.
368            br'(?:[A-Za-z0-9+/\-_]{4})*'
369            # The last block of 4 chars may have one or two padding chars (=).
370            br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
371
372
def detokenize_base64_live(detokenizer: _Detokenizer,
                           input_file: BinaryIO,
                           output: BinaryIO,
                           prefix: Union[str, bytes] = BASE64_PREFIX,
                           recursion: int = DEFAULT_RECURSION) -> None:
    """Detokenizes a stream as it is read, one character at a time.

    This is SLOW for big files, but it writes each chunk as soon as it is
    available, which suits inputs such as pipes.

    Args:
      detokenizer: the detokenizer with which to decode messages
      input_file: the binary file from which to read
      output: the binary file to which to write the detokenized data
      prefix: one-character prefix that signals the start of a message
      recursion: how many levels to recursively decode
    """
    if isinstance(prefix, str):
        prefix_bytes = prefix.encode()
    else:
        prefix_bytes = prefix

    message_pattern = _base64_message_regex(prefix_bytes)

    def detokenize_message(data: bytes) -> bytes:
        replacer = _detokenize_prefixed_base64(detokenizer, prefix_bytes,
                                               recursion)
        return message_pattern.sub(replacer, data)

    # The decoder groups the characters that may appear in a Base64 message.
    decoder = PrefixedMessageDecoder(
        prefix, string.ascii_letters + string.digits + '+/-_=')

    for chunk in decoder.transform(input_file, detokenize_message):
        output.write(chunk)

        # Flush each line to prevent delays when piping between processes.
        if b'\n' in chunk:
            output.flush()
396
397
def detokenize_base64_to_file(detokenizer: _Detokenizer,
                              data: bytes,
                              output: BinaryIO,
                              prefix: Union[str, bytes] = BASE64_PREFIX,
                              recursion: int = DEFAULT_RECURSION) -> None:
    """Detokenizes the prefixed Base64 messages in data; writes to a file.

    Args:
      detokenizer: the detokenizer with which to decode messages
      data: the binary data to decode
      output: the binary file to which to write the decoded data
      prefix: one-character prefix that signals the start of a message
      recursion: how many levels to recursively decode
    """
    if isinstance(prefix, str):
        prefix = prefix.encode()

    message_pattern = _base64_message_regex(prefix)
    replacer = _detokenize_prefixed_base64(detokenizer, prefix, recursion)

    # Everything that is not a recognized message is written unchanged.
    output.write(message_pattern.sub(replacer, data))
408
409
def detokenize_base64(detokenizer: _Detokenizer,
                      data: bytes,
                      prefix: Union[str, bytes] = BASE64_PREFIX,
                      recursion: int = DEFAULT_RECURSION) -> bytes:
    """Returns the data with its prefixed Base64 messages detokenized.

    Args:
      detokenizer: the detokenizer with which to decode messages
      data: the binary data to decode
      prefix: one-character byte string that signals the start of a message
      recursion: how many levels to recursively decode

    Returns:
      copy of the data with all recognized tokens decoded
    """
    # Collect the output in memory and return it as a single bytes object.
    with io.BytesIO() as buffer:
        detokenize_base64_to_file(detokenizer, data, buffer, prefix,
                                  recursion)
        return buffer.getvalue()
428
429
def _follow_and_detokenize_file(detokenizer: _Detokenizer,
                                file: BinaryIO,
                                output: BinaryIO,
                                prefix: Union[str, bytes],
                                poll_period_s: float = 0.01) -> None:
    """Detokenizes a file and keeps polling it for appended data.

    Runs until interrupted with KeyboardInterrupt.

    Args:
      detokenizer: the detokenizer with which to decode messages
      file: the binary file to read and poll
      output: the binary file to which to write the detokenized data
      prefix: one-character prefix that signals the start of a message
      poll_period_s: how long to sleep when there is no new data
    """
    try:
        while True:
            chunk = file.read()

            if not chunk:
                # Nothing new yet; wait before checking again.
                time.sleep(poll_period_s)
                continue

            detokenize_base64_to_file(detokenizer, chunk, output, prefix)
            output.flush()
    except KeyboardInterrupt:
        pass
447
448
def _handle_base64(databases, input_file: BinaryIO, output: BinaryIO,
                   prefix: str, show_errors: bool, follow: bool) -> None:
    """Runs the base64 subcommand of the command line interface.

    Args:
      databases: the token databases to merge and use for detokenization
      input_file: the file from which to read
      output: the file to which to write the detokenized data
      prefix: one-character prefix that signals the start of a message
      show_errors: whether to show error messages for arguments that cannot
          be decoded
      follow: whether to keep detokenizing data appended to the input
    """
    # argparse.FileType doesn't correctly handle - for binary files.
    input_file = sys.stdin.buffer if input_file is sys.stdin else input_file
    output = sys.stdout.buffer if output is sys.stdout else output

    merged_database = tokens.Database.merged(*databases)
    detokenizer = Detokenizer(merged_database, show_errors=show_errors)

    if follow:
        _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
        return

    if not input_file.seekable():
        # For non-seekable inputs (e.g. pipes), read one character at a time.
        detokenize_base64_live(detokenizer, input_file, output, prefix)
        return

    # Process seekable files all at once, which is MUCH faster.
    detokenize_base64_to_file(detokenizer, input_file.read(), output, prefix)
471
472
def _parse_args() -> argparse.Namespace:
    """Parses and returns the command line arguments.

    The returned namespace has a handler attribute: the function to call
    with the remaining arguments as keyword arguments.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # Without a subcommand, the handler just prints the help text.
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(help='Encoding of the input.')

    base64_help = 'Detokenize Base64-encoded data from a file or stdin.'
    base64_parser = subparsers.add_parser(
        'base64',
        description=base64_help,
        parents=[database.token_databases_parser()],
        help=base64_help)
    base64_parser.set_defaults(handler=_handle_base64)

    base64_parser.add_argument(
        '-i',
        '--input',
        dest='input_file',
        type=argparse.FileType('rb'),
        default=sys.stdin.buffer,
        help='The file from which to read; provide - or omit for stdin.')

    base64_parser.add_argument(
        '-f',
        '--follow',
        action='store_true',
        help=('Detokenize data appended to input_file as it grows; '
              'similar to tail -f.'))

    base64_parser.add_argument(
        '-o',
        '--output',
        type=argparse.FileType('wb'),
        default=sys.stdout.buffer,
        help=('The file to which to write the output; provide - or omit '
              'for stdout.'))

    base64_parser.add_argument(
        '-p',
        '--prefix',
        default=BASE64_PREFIX,
        help=('The one-character prefix that signals the start of a '
              'Base64-encoded message. (default: $)'))

    base64_parser.add_argument(
        '-s',
        '--show_errors',
        action='store_true',
        help=('Show error messages instead of conversion specifiers when '
              'arguments cannot be decoded.'))

    return parser.parse_args()
523
524
def main() -> int:
    """Runs the handler selected on the command line; returns exit code 0."""
    options = vars(_parse_args())

    # The handler is not one of its own arguments, so remove it before
    # passing everything else as keyword arguments.
    handler = options.pop('handler')
    handler(**options)

    return 0
533
534
# Run the command line interface when this file is executed as a script.
if __name__ == '__main__':
    # The command line tools only support Python 3.
    if sys.version_info[0] < 3:
        sys.exit('ERROR: The detokenizer command line tools require Python 3.')
    # The value returned by main() is used as the process exit status.
    sys.exit(main())
539