1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Creates and manages token databases.
16
17This module manages reading tokenized strings from ELF files and building and
18maintaining token databases.
19"""
20
21import argparse
22from datetime import datetime
23import glob
24import json
25import logging
26import os
27from pathlib import Path
28import re
29import struct
30import sys
31from typing import (Any, Callable, Dict, Iterable, Iterator, List, Pattern,
32                    Set, TextIO, Tuple, Union)
33
34try:
35    from pw_tokenizer import elf_reader, tokens
36except ImportError:
37    # Append this path to the module search path to allow running this module
38    # without installing the pw_tokenizer package.
39    sys.path.append(os.path.dirname(os.path.dirname(
40        os.path.abspath(__file__))))
41    from pw_tokenizer import elf_reader, tokens
42
# Module-wide logger; every log call in this module goes through it.
_LOG = logging.getLogger(name='pw_tokenizer')
44
45
def _elf_reader(elf) -> elf_reader.Elf:
    """Returns elf itself if it is already an elf_reader.Elf; otherwise wraps it."""
    if isinstance(elf, elf_reader.Elf):
        return elf
    return elf_reader.Elf(elf)
48
49
# Magic number used to indicate the beginning of a tokenized string entry. This
# value MUST match the value of _PW_TOKENIZER_ENTRY_MAGIC in
# pw_tokenizer/public/pw_tokenizer/internal/tokenize_string.h.
_TOKENIZED_ENTRY_MAGIC = 0xBAA98DEE

# Header preceding each tokenized string entry: four little-endian uint32s
# (magic, token, domain length, string length).
_ENTRY = struct.Struct('<4I')

# Sections holding tokenized string entries: .pw_tokenizer.entries, optionally
# followed by a '.' and a suffix of digits and underscores. Both dots are
# escaped so that they only match a literal '.'.
_TOKENIZED_ENTRY_SECTIONS = re.compile(
    r'^\.pw_tokenizer\.entries(?:\.[_\d]+)?$')

# Sections used by legacy ELF files, which stored null-terminated strings in
# per-domain sections named .pw_tokenized.<domain>[.<number>].
_LEGACY_STRING_SECTIONS = re.compile(
    r'^\.pw_tokenized\.(?P<domain>[^.]+)(?:\.\d+)?$')

_ERROR_HANDLER = 'surrogateescape'  # How to deal with UTF-8 decoding errors
62
63
class Error(Exception):
    """Raised when tokenized string entries cannot be extracted from an ELF."""
66
67
def _read_tokenized_entries(
        data: bytes,
        domain: Pattern[str]) -> Iterator[tokens.TokenizedStringEntry]:
    """Parses the contents of tokenized string entry sections.

    Each entry is an _ENTRY header (magic, token, domain length, string length)
    followed by the domain and then the string, each including its null
    terminator. Trailing bytes too short to hold a header are ignored.

    Args:
      data: raw contents of the tokenized entry sections
      domain: only entries whose domain fully matches this pattern are yielded

    Yields:
      a TokenizedStringEntry for each entry in a matching domain

    Raises:
      Error: an entry has the wrong magic number, a zero domain or string
          length, runs past the end of data, or has a domain or string that is
          not null terminated
    """
    index = 0

    while index + _ENTRY.size <= len(data):
        magic, token, domain_len, string_len = _ENTRY.unpack_from(data, index)

        if magic != _TOKENIZED_ENTRY_MAGIC:
            raise Error(
                f'Expected magic number 0x{_TOKENIZED_ENTRY_MAGIC:08x}, '
                f'found 0x{magic:08x}')

        # Both lengths count the null terminator, so neither may be zero. A
        # zero length would make the terminator checks below read header bytes
        # instead of the domain or string.
        if domain_len == 0 or string_len == 0:
            raise Error(
                f'Entry for token 0x{token:08x} has domain length '
                f'{domain_len} and string length {string_len}; both must be '
                'at least 1 to include a null terminator')

        start = index + _ENTRY.size
        index = start + domain_len + string_len

        # Report a truncated entry as an Error rather than an IndexError.
        if index > len(data):
            raise Error(
                f'Entry for token 0x{token:08x} needs {index - start} bytes '
                f'after its header, but only {len(data) - start} remain')

        # Create the entries, trimming null terminators.
        entry = tokens.TokenizedStringEntry(
            token,
            data[start + domain_len:index - 1].decode(errors=_ERROR_HANDLER),
            data[start:start + domain_len - 1].decode(errors=_ERROR_HANDLER),
        )

        if data[start + domain_len - 1] != 0:
            raise Error(
                f'Domain {entry.domain} for {entry.string} not null terminated'
            )

        if data[index - 1] != 0:
            raise Error(f'String {entry.string} is not null terminated')

        if domain.fullmatch(entry.domain):
            yield entry
101
102
def _read_tokenized_strings(sections: Dict[str, bytes],
                            domain: Pattern[str]) -> Iterator[tokens.Database]:
    """Reads legacy null-terminated tokenized strings from ELF sections.

    Args:
      sections: section contents keyed by section name
      domain: only sections whose domain fully matches this pattern are read;
          the default domain pattern selects the legacy "default" domain

    Yields:
      a Database for each matching legacy string section
    """
    # Legacy ELF files used "default" as the default domain instead of "". Remap
    # the default if necessary.
    if domain.pattern == tokens.DEFAULT_DOMAIN:
        domain = re.compile('default')

    for section, data in sections.items():
        match = _LEGACY_STRING_SECTIONS.match(section)
        if not match:
            continue

        section_domain = match.group('domain')

        # Use fullmatch, as _read_tokenized_entries does, so that a pattern
        # does not also select domains that merely start with it (e.g.
        # "default" must not select "default_extra").
        if domain.fullmatch(section_domain):
            yield tokens.Database.from_strings(
                (s.decode(errors=_ERROR_HANDLER) for s in data.split(b'\0')),
                section_domain)
116
117
def _database_from_elf(elf, domain: Pattern[str]) -> tokens.Database:
    """Reads the tokenized strings from an elf_reader.Elf or ELF file object.

    Tokenized string entry sections take precedence: the legacy per-domain
    string sections are only consulted when no entry section data is found.

    Args:
      elf: an elf_reader.Elf or an ELF file object
      domain: pattern selecting which tokenization domains to read

    Returns:
      a Database of the strings in matching domains; empty if the ELF has
      neither entry sections nor legacy string sections
    """
    # Log the pattern text; formatting the compiled pattern would print its
    # repr, e.g. re.compile('').
    _LOG.debug('Reading tokenized strings in domain "%s" from %s',
               domain.pattern, elf)

    reader = _elf_reader(elf)

    # Read tokenized string entries.
    section_data = reader.dump_section_contents(_TOKENIZED_ENTRY_SECTIONS)
    if section_data is not None:
        return tokens.Database(_read_tokenized_entries(section_data, domain))

    # Read legacy null-terminated string entries.
    sections = reader.dump_sections(_LEGACY_STRING_SECTIONS)
    if sections:
        return tokens.Database.merged(
            *_read_tokenized_strings(sections, domain))

    return tokens.Database([])
136
137
def tokenization_domains(elf) -> Iterator[str]:
    """Lists all tokenization domains in an ELF file.

    Each domain is yielded exactly once, in the order in which it first appears
    in the ELF's tokenized entries (or, for legacy ELF files, its sections).

    Args:
      elf: an elf_reader.Elf or an ELF file object

    Yields:
      the name of each distinct tokenization domain
    """
    reader = _elf_reader(elf)
    section_data = reader.dump_section_contents(_TOKENIZED_ENTRY_SECTIONS)

    if section_data is not None:
        domains: Iterable[str] = (
            entry.domain for entry in _read_tokenized_entries(
                section_data, re.compile('.*')))
    else:  # Check for the legacy domain sections
        # Numbered sections (e.g. .pw_tokenized.default.1 and .2) share a
        # domain, so duplicates are possible here.
        matches = (_LEGACY_STRING_SECTIONS.match(section.name)
                   for section in reader.sections)
        domains = (match.group('domain') for match in matches if match)

    # dict.fromkeys drops duplicates while keeping first-seen order, which
    # (unlike a set) gives a deterministic result.
    yield from dict.fromkeys(domains)
151
152
def read_tokenizer_metadata(elf) -> Dict[str, int]:
    """Reads the metadata entries from an ELF.

    The .pw_tokenizer.info section is read as a sequence of 16-byte records:
    a 12-byte null-padded key followed by a little-endian uint32 value. Keys
    that are not valid UTF-8 are logged and skipped.

    Args:
      elf: an elf_reader.Elf or an ELF file object

    Returns:
      a dict mapping each decoded key to its value; empty if the ELF has no
      .pw_tokenizer.info section

    Raises:
      struct.error: the section size is not a multiple of 16 bytes
    """
    sections = _elf_reader(elf).dump_section_contents(r'\.pw_tokenizer\.info')

    metadata: Dict[str, int] = {}
    if sections is None:
        return metadata

    # Use an explicit little-endian layout, as _ENTRY does, instead of the
    # host's native byte order and alignment.
    for key, value in struct.iter_unpack('<12sI', sections):
        try:
            metadata[key.rstrip(b'\0').decode()] = value
        except UnicodeDecodeError as err:
            _LOG.error('Failed to decode metadata key %r: %s',
                       key.rstrip(b'\0'), err)

    return metadata
167
168
def _load_token_database(db, domain: Pattern[str]) -> tokens.Database:
    """Converts a single database source into a tokens.Database.

    Sources are handled in this order:
      None: an empty Database
      tokens.Database: returned unchanged
      elf_reader.Elf: strings in domains matching `domain` are read from it
      str or Path: must name an existing file, which is read as an ELF when
          elf_reader.compatible_file accepts it, else as a tokens.DatabaseFile
      anything else: treated as a file object, which is read as an ELF when
          compatible, else via tokens.DatabaseFile(db.name) when it has a name
          that exists on disk, else parsed as CSV directly

    Raises:
      FileNotFoundError: db is a str or Path that does not exist
    """
    if db is None:
        return tokens.Database()
    if isinstance(db, tokens.Database):
        return db
    if isinstance(db, elf_reader.Elf):
        return _database_from_elf(db, domain)

    if not isinstance(db, (str, Path)):
        # Not a path, so this must be a file object.
        if elf_reader.compatible_file(db):
            return _database_from_elf(db, domain)

        # Load CSV or packed binary data using the file object's path.
        if hasattr(db, 'name') and os.path.exists(db.name):
            return tokens.DatabaseFile(db.name)

        return tokens.Database(tokens.parse_csv(db))

    # A str or Path must name an existing ELF, CSV, or packed binary file.
    if not os.path.exists(db):
        raise FileNotFoundError(
            f'"{db}" is not a path to a token database')

    with open(db, 'rb') as file:
        if elf_reader.compatible_file(file):
            return _database_from_elf(file, domain)

    return tokens.DatabaseFile(db)
204
205
def load_token_database(
    *databases,
    domain: Union[str,
                  Pattern[str]] = tokens.DEFAULT_DOMAIN) -> tokens.Database:
    """Merges database objects, ELFs, CSVs, or binary files into one Database.

    Args:
      *databases: sources accepted by _load_token_database
      domain: regular expression (string or compiled) selecting which
          tokenization domains to read from ELF files

    Returns:
      a single Database merged from every source
    """
    pattern = re.compile(domain)
    loaded = [_load_token_database(db, pattern) for db in databases]
    return tokens.Database.merged(*loaded)
214
215
def database_summary(db: tokens.Database) -> Dict[str, Any]:
    """Returns a simple report of properties of the database.

    Sizes count each string's UTF-8 bytes plus one byte for its null
    terminator, rather than its number of characters.

    Args:
      db: the database to summarize

    Returns:
      a dict with present_entries and present_size_bytes (entries without a
      date_removed), total_entries and total_size_bytes (all entries), and
      collisions (each colliding token mapped to the strings that share it)
    """
    def size_bytes(string: str) -> int:
        # Strings read from ELF files are decoded with _ERROR_HANDLER, so
        # encoding them the same way recovers their original UTF-8 bytes.
        # Fall back to surrogatepass for any surrogates the handler rejects.
        try:
            encoded = string.encode(errors=_ERROR_HANDLER)
        except UnicodeEncodeError:
            encoded = string.encode(errors='surrogatepass')
        return len(encoded) + 1  # + 1 for the null terminator

    entries = db.entries()
    present = [entry for entry in entries if not entry.date_removed]
    collisions = {
        token: [e.string for e in colliding]
        for token, colliding in db.collisions()
    }

    return dict(
        present_entries=len(present),
        present_size_bytes=sum(size_bytes(entry.string) for entry in present),
        total_entries=len(entries),
        total_size_bytes=sum(size_bytes(entry.string) for entry in entries),
        collisions=collisions,
    )
232
233
# Type of the report built by generate_reports: maps each database path to a
# dict that maps each tokenization domain to its database_summary() dict.
_DatabaseReport = Dict[str, Dict[str, Dict[str, Any]]]
235
236
def generate_reports(paths: Iterable[Path]) -> _DatabaseReport:
    """Returns a dictionary with information about the provided databases.

    ELF files get one summary per tokenization domain they contain. Other
    token database files get a single summary under the default domain ("").

    Args:
      paths: ELF or token database files to report on

    Returns:
      a dict mapping str(path) to a dict of domain -> database_summary()
    """
    reports: _DatabaseReport = {}

    for path in paths:
        with path.open('rb') as file:
            if elf_reader.compatible_file(file):
                domains = list(tokenization_domains(file))
            else:
                domains = ['']

        # The domain names are literal strings, but load_token_database treats
        # its domain argument as a regular expression. Escape each name so
        # regex metacharacters in it only match themselves.
        reports[str(path)] = {
            domain: database_summary(
                load_token_database(path, domain=re.escape(domain)))
            for domain in domains
        }

    return reports
257
258
def _handle_create(databases, database, force, output_type, include, exclude,
                   replace):
    """Creates a token database file from one or more ELF files.

    Args:
      databases: token databases to merge into the new database
      database: path of the file to create, or '-' to write to stdout
      force: overwrite the database file if it already exists
      output_type: 'csv' or 'binary'
      include: regexes passed to Database.filter as its include argument
      exclude: regexes passed to Database.filter as its exclude argument
      replace: (regex, replacement) pairs passed to Database.filter

    Raises:
      FileExistsError: the database file exists and force is not set
      ValueError: output_type is not 'csv' or 'binary'
    """
    to_stdout = database == '-'

    if not to_stdout and not force and os.path.exists(database):
        raise FileExistsError(
            f'The file {database} already exists! Use --force to overwrite.')

    # Validate the output type before opening the output file, so that a bad
    # type cannot truncate an existing database.
    if output_type == 'csv':
        write_database = tokens.write_csv
    elif output_type == 'binary':
        write_database = tokens.write_binary
    else:
        raise ValueError(f'Unknown database type "{output_type}"')

    # Build the database before opening the output file, so that an error
    # while merging or filtering leaves an existing file untouched.
    merged = tokens.Database.merged(*databases)
    merged.filter(include, exclude, replace)

    if to_stdout:
        # Must write bytes to stdout; use sys.stdout.buffer. Flush it, but do
        # not close it: stdout is still needed after this function returns.
        fd = sys.stdout.buffer
        write_database(merged, fd)
        fd.flush()
    else:
        with open(database, 'wb') as fd:
            write_database(merged, fd)

    _LOG.info('Wrote database with %d entries to %s as %s', len(merged),
              fd.name, output_type)
285
286
def _handle_add(token_database, databases):
    """Adds the entries of every source database to token_database and saves it.

    Entries missing from the sources are left as they are; nothing is marked
    as removed.
    """
    size_before = len(token_database)

    for db in databases:
        token_database.add(db.entries())

    token_database.write_to_file()

    added = len(token_database) - size_before
    _LOG.info('Added %d entries to %s', added, token_database.path)
297
298
def _handle_mark_removed(token_database, databases, date):
    """Marks entries as removed based on the given source databases and saves.

    The not-yet-removed entries of the merged sources are passed to
    token_database.mark_removed along with date; per the 'mark_removed'
    command's help, entries absent from that set are marked as removed.
    """
    source_entries = tokens.Database.merged(*databases).entries()
    still_present = (entry for entry in source_entries
                     if not entry.date_removed)
    marked_removed = token_database.mark_removed(still_present, date)

    token_database.write_to_file()

    _LOG.info('Marked %d of %d entries as removed in %s', len(marked_removed),
              len(token_database), token_database.path)
308
309
def _handle_purge(token_database, before):
    """Purges removed entries via token_database.purge(before), then saves."""
    purged_entries = token_database.purge(before)
    token_database.write_to_file()

    count = len(purged_entries)
    _LOG.info('Removed %d entries from %s', count, token_database.path)
315
316
def _handle_report(token_database_or_elf: List[Path], output: TextIO) -> None:
    """Writes a JSON report (2-space indent, trailing newline) to output."""
    report = generate_reports(token_database_or_elf)
    output.write(json.dumps(report, indent=2) + '\n')
320
321
def expand_paths_or_globs(*paths_or_globs: str) -> Iterable[Path]:
    """Expands any globs in a list of paths; raises FileNotFoundError.

    Arguments that name existing paths are yielded unchanged, without glob
    expansion. Other arguments are expanded as recursive globs, yielding (in
    sorted order) only the matching regular files that end in .csv or that
    elf_reader.compatible_file accepts. This is a generator, so errors are
    raised during iteration.

    Raises:
      FileNotFoundError: an argument is neither an existing path nor a glob
    """
    for path_or_glob in paths_or_globs:
        if os.path.exists(path_or_glob):
            # This is a valid path; yield it without evaluating it as a glob.
            yield Path(path_or_glob)
            continue

        paths = glob.glob(path_or_glob, recursive=True)

        # If no paths were found and the path is not a glob, raise an Error.
        if not paths and not any(c in path_or_glob for c in '*?[]!'):
            raise FileNotFoundError(f'{path_or_glob} is not a valid path')

        # glob returns paths in arbitrary order; sort for determinism.
        for path in sorted(paths):
            # Globs can match directories, which are never databases. Check the
            # cheap .csv suffix before opening the file to test it as an ELF.
            if not os.path.isfile(path):
                continue
            if path.endswith('.csv') or elf_reader.compatible_file(path):
                yield Path(path)
339
340
class ExpandGlobs(argparse.Action):
    """Argparse action that stores the expanded paths of its values.

    The values are expanded with expand_paths_or_globs, and the resulting list
    replaces the destination attribute on the namespace.
    """
    def __call__(self, parser, namespace, values, unused_option_string=None):
        expanded = [*expand_paths_or_globs(*values)]
        setattr(namespace, self.dest, expanded)
345
346
def _read_elf_with_domain(elf: str,
                          domain: Pattern[str]) -> Iterable[tokens.Database]:
    """Reads a tokenization domain from each ELF matching a path or glob.

    Args:
      elf: path or glob naming one or more ELF files
      domain: pattern selecting the tokenization domains to read

    Yields:
      one Database per matching ELF file

    Raises:
      ValueError: a matching file is not an ELF file
    """
    for path in expand_paths_or_globs(elf):
        with path.open('rb') as file:
            if not elf_reader.compatible_file(file):
                # Name the specific file, since elf may be a glob, and show the
                # domain's pattern text instead of the compiled pattern's repr.
                raise ValueError(
                    f'{path} is not an ELF file, '
                    f'but the "{domain.pattern}" domain was specified')

            yield _database_from_elf(file, domain)
356
357
class LoadTokenDatabases(argparse.Action):
    """Argparse action that reads tokenized databases from paths or globs.

    ELF files may have #domain appended to them to specify a tokenization domain
    other than the default.

    Loading failures are reported through parser.error.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        databases: List[tokens.Database] = []
        paths: Set[Path] = set()

        # The argument or path being processed, for error messages. Unlike a
        # loop variable, it is always bound and names the input that failed.
        current: Union[str, Path] = ''

        try:
            for value in values:
                current = value
                if value.count('#') == 1:
                    path, domain = value.split('#')
                    current = path
                    databases.extend(
                        _read_elf_with_domain(path, re.compile(domain)))
                else:
                    paths.update(expand_paths_or_globs(value))

            # Sets iterate in arbitrary order; sort to load deterministically.
            for path in sorted(paths):
                current = path
                databases.append(load_token_database(path))
        except tokens.DatabaseFormatError as err:
            parser.error(
                f'argument elf_or_token_database: {current} is not a supported '
                'token database file. Only ELF files or token databases (CSV '
                f'or binary format) are supported. {err}. ')
        except FileNotFoundError as err:
            parser.error(f'argument elf_or_token_database: {err}')
        except Exception:  # pylint: disable=broad-except
            # Catch Exception rather than everything so that KeyboardInterrupt
            # and SystemExit still propagate.
            _LOG.exception('Failed to load token database %s', current)
            parser.error('argument elf_or_token_database: Error occurred '
                         f'while loading token database {current}')

        setattr(namespace, self.dest, databases)
392
393
def token_databases_parser(nargs: str = '+') -> argparse.ArgumentParser:
    """Returns an argument parser for reading token databases.

    These arguments can be added to another parser using the parents arg.

    Args:
      nargs: argparse nargs for the positional databases argument (e.g. '+'
          to require at least one database, '*' to allow none)

    Returns:
      a parser without -h/--help whose 'databases' argument is loaded with
      the LoadTokenDatabases action
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        'databases',
        metavar='elf_or_token_database',
        nargs=nargs,
        action=LoadTokenDatabases,
        help=('ELF or token database files from which to read strings and '
              'tokens. For ELF files, the tokenization domain to read from '
              'may be specified after the path as #domain_name (e.g. '
              'foo.elf#TEST_DOMAIN). Unless specified, only the default '
              'domain ("") is read from ELF files; .* reads all domains. '
              'Globs are expanded to compatible database files.'))
    return parser
412
413
def _parse_args():
    """Parse and return command line arguments.

    Returns:
      a (handler, args) tuple, where handler is the function for the selected
      subcommand (or one that prints help if no subcommand was given) and args
      is the parsed Namespace with the handler attribute removed, so that it
      can be passed to the handler as keyword arguments
    """
    def year_month_day(value) -> datetime:
        """Parses "today" or a date in tokens.DATE_FORMAT."""
        if value == 'today':
            return datetime.now()

        return datetime.strptime(value, tokens.DATE_FORMAT)

    # argparse uses the type function's __name__ when reporting an invalid
    # value, so give it a descriptive name.
    year_month_day.__name__ = 'year-month-day (YYYY-MM-DD)'

    # Shared command line options.
    option_db = argparse.ArgumentParser(add_help=False)
    option_db.add_argument('-d',
                           '--database',
                           dest='token_database',
                           type=tokens.DatabaseFile,
                           required=True,
                           help='The database file to update.')

    option_tokens = token_databases_parser('*')

    # Top-level argument parser.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    # Without a subcommand, the handler just prints the help text.
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(
        help='Tokenized string database management actions:')

    # The 'create' command creates a database file.
    subparser = subparsers.add_parser(
        'create',
        parents=[option_tokens],
        help=
        'Creates a database with tokenized strings from one or more sources.')
    subparser.set_defaults(handler=_handle_create)
    subparser.add_argument(
        '-d',
        '--database',
        required=True,
        help='Path to the database file to create; use - for stdout.')
    subparser.add_argument(
        '-t',
        '--type',
        dest='output_type',
        choices=('csv', 'binary'),
        default='csv',
        help='Which type of database to create. (default: csv)')
    subparser.add_argument('-f',
                           '--force',
                           action='store_true',
                           help='Overwrite the database if it exists.')
    subparser.add_argument(
        '-i',
        '--include',
        type=re.compile,
        default=[],
        action='append',
        help=('If provided, at least one of these regular expressions must '
              'match for a string to be included in the database.'))
    subparser.add_argument(
        '-e',
        '--exclude',
        type=re.compile,
        default=[],
        action='append',
        help=('If provided, none of these regular expressions may match for a '
              'string to be included in the database.'))

    # Matches a '/' that is not preceded by a backslash.
    unescaped_slash = re.compile(r'(?<!\\)/')

    def replacement(value: str) -> Tuple[Pattern[str], str]:
        """Parses "search_regex/replacement" into a (regex, replacement)."""
        try:
            find, sub = unescaped_slash.split(value, 1)
        except ValueError as err:
            raise argparse.ArgumentTypeError(
                'replacements must be specified as "search_regex/replacement"'
            ) from err

        try:
            return re.compile(find.replace(r'\/', '/')), sub
        except re.error as err:
            raise argparse.ArgumentTypeError(
                f'"{value}" is not a valid regular expression: {err}'
            ) from err

    subparser.add_argument(
        '--replace',
        type=replacement,
        default=[],
        action='append',
        help=('If provided, replaces text that matches a regular expression. '
              'This can be used to replace sensitive terms in a token '
              'database that will be distributed publicly. The expression and '
              'replacement are specified as "search_regex/replacement". '
              'Plain slash characters in the regex must be escaped with a '
              r'backslash (\/). The replacement text may include '
              'backreferences for captured groups in the regex.'))

    # The 'add' command adds strings to a database from a set of ELFs.
    subparser = subparsers.add_parser(
        'add',
        parents=[option_db, option_tokens],
        help=(
            'Adds new strings to a database with tokenized strings from a set '
            'of ELF files or other token databases. Missing entries are NOT '
            'marked as removed.'))
    subparser.set_defaults(handler=_handle_add)

    # The 'mark_removed' command marks removed entries to match a set of ELFs.
    subparser = subparsers.add_parser(
        'mark_removed',
        parents=[option_db, option_tokens],
        help=(
            'Updates a database with tokenized strings from a set of strings. '
            'Strings not present in the set remain in the database but are '
            'marked as removed. New strings are NOT added.'))
    subparser.set_defaults(handler=_handle_mark_removed)
    subparser.add_argument(
        '--date',
        type=year_month_day,
        help=('The removal date to use for all strings. '
              'May be YYYY-MM-DD or "today". (default: today)'))

    # The 'purge' command removes old entries.
    subparser = subparsers.add_parser(
        'purge',
        parents=[option_db],
        help='Purges removed strings from a database.')
    subparser.set_defaults(handler=_handle_purge)
    subparser.add_argument(
        '-b',
        '--before',
        type=year_month_day,
        help=('Delete all entries removed on or before this date. '
              'May be YYYY-MM-DD or "today".'))

    # The 'report' command prints a report about a database.
    subparser = subparsers.add_parser('report',
                                      help='Prints a report about a database.')
    subparser.set_defaults(handler=_handle_report)
    subparser.add_argument(
        'token_database_or_elf',
        nargs='+',
        action=ExpandGlobs,
        help='The ELF files or token databases about which to generate reports.'
    )
    subparser.add_argument(
        '-o',
        '--output',
        type=argparse.FileType('w'),
        default=sys.stdout,
        help='The file to which to write the output; use - for stdout.')

    args = parser.parse_args()

    # The handler is returned separately so the remaining attributes can be
    # passed to it as keyword arguments.
    handler = args.handler
    del args.handler

    return handler, args
573
574
def _init_logging(level: int) -> None:
    """Attaches a stderr handler that shows messages at `level` and above.

    The logger itself is set to DEBUG, so filtering is done by the handler.
    """
    stderr_handler = logging.StreamHandler()
    stderr_handler.setLevel(level)

    formatter = logging.Formatter(
        fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
        datefmt='%H:%M:%S')
    stderr_handler.setFormatter(formatter)

    _LOG.setLevel(logging.DEBUG)
    _LOG.addHandler(stderr_handler)
585
586
def _main(handler: Callable, args: argparse.Namespace) -> int:
    """Sets up INFO logging, runs handler with args as keywords; returns 0."""
    _init_logging(logging.INFO)

    handler_kwargs = vars(args)
    handler(**handler_kwargs)

    return 0
591
592
if __name__ == '__main__':
    # Parse the command line, run the selected handler, and exit with its
    # return code.
    _HANDLER, _ARGS = _parse_args()
    sys.exit(_main(_HANDLER, _ARGS))
595