1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Builds and manages databases of tokenized strings."""
15
16import collections
17import csv
18from dataclasses import dataclass
19from datetime import datetime
20import io
21import logging
22from pathlib import Path
23import re
24import struct
25from typing import (BinaryIO, Callable, Dict, Iterable, Iterator, List,
26                    NamedTuple, Optional, Pattern, Tuple, Union, ValuesView)
27
# strftime()/strptime() format of the removal date column in CSV databases.
DATE_FORMAT = '%Y-%m-%d'

# Domain assigned to entries for which no domain is specified.
DEFAULT_DOMAIN = ''

# Default number of characters to hash. Only relevant when tokens have to be
# computed from the plain strings of a legacy-style ELF; new tokenized string
# entries store the token next to the string, so no hashing is needed.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

# Multiplier of the 65599 hash; successive characters are weighted by
# successive powers of this constant.
TOKENIZER_HASH_CONSTANT = 65599

# Logger shared by the functions in this module.
_LOG = logging.getLogger('pw_tokenizer')
42
43
44def _value(char: Union[int, str]) -> int:
45    return char if isinstance(char, int) else ord(char)
46
47
def pw_tokenizer_65599_fixed_length_hash(string: Union[str, bytes],
                                         hash_length: int) -> int:
    """Computes the 32-bit 65599 hash of the provided string.

    This hash function is only used when adding tokens from legacy-style
    tokenized strings in an ELF, which do not include the token.

    Args:
      string: the str or bytes to hash
      hash_length: the maximum number of leading characters to hash

    Returns:
      The hash as an unsigned 32-bit integer.
    """
    mask_32_bits = 0xFFFFFFFF

    # The hash is seeded with the full length of the string, even though only
    # the first hash_length characters are mixed in below.
    result = len(string)
    multiplier = TOKENIZER_HASH_CONSTANT

    for element in string[:hash_length]:
        result = (result + multiplier * _value(element)) & mask_32_bits
        multiplier = (multiplier * TOKENIZER_HASH_CONSTANT) & mask_32_bits

    return result
63
64
def default_hash(string: Union[str, bytes]) -> int:
    """Hashes the string with the 65599 hash and DEFAULT_C_HASH_LENGTH."""
    return pw_tokenizer_65599_fixed_length_hash(
        string, hash_length=DEFAULT_C_HASH_LENGTH)
67
68
69class _EntryKey(NamedTuple):
70    """Uniquely refers to an entry."""
71    token: int
72    string: str
73
74
@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """One tokenized string along with its domain and removal date.

    A date_removed of None means that the string has not been removed.
    """
    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: Optional[datetime] = None

    def key(self) -> _EntryKey:
        """Returns the (token, string) pair that makes this entry unique."""
        return _EntryKey(token=self.token, string=self.string)

    def update_date_removed(self,
                            new_date_removed: Optional[datetime]) -> None:
        """Replaces self.date_removed with new_date_removed if it is newer.

        None (not removed) counts as newer than any actual date, so an entry
        that has no removal date is never changed by this function.
        """
        current = self.date_removed

        if current is not None and (new_date_removed is None
                                    or new_date_removed > current):
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Orders entries by token, then removal date, then string."""
        if self.token != other.token:
            return self.token < other.token

        if self.date_removed != other.date_removed:
            # Removal dates sort newest-first. Entries that are still present
            # (None) are treated as datetime.max, so they precede removed ones.
            own_date = self.date_removed or datetime.max
            other_date = other.date_removed or datetime.max
            return other_date < own_date

        return self.string < other.string

    def __str__(self) -> str:
        """Returns the tokenized string itself."""
        return self.string
112
113
class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects.

    Entries are stored in a dict keyed by (token, string). A token may
    therefore map to several strings (a collision), but each (token, string)
    pair appears at most once.
    """
    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database.

        Args:
          entries: the initial entries; they are stored by reference, and if
              several share the same (token, string) key, the last one is kept
        """
        # The database dict stores each unique (token, string) entry.
        self._database: Dict[_EntryKey, TokenizedStringEntry] = {
            entry.key(): entry
            for entry in entries
        }

        # This is a cache for fast token lookup that is built as needed.
        self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None

    @classmethod
    def from_strings(
            cls,
            strings: Iterable[str],
            domain: str = DEFAULT_DOMAIN,
            tokenize: Callable[[str], int] = default_hash) -> 'Database':
        """Creates a Database from an iterable of strings.

        Args:
          strings: the strings to add to the database
          domain: the domain to assign to every entry
          tokenize: function that calculates the token for a string
        """
        return cls((TokenizedStringEntry(tokenize(string), string, domain)
                    for string in strings))

    @classmethod
    def merged(cls, *databases: 'Database') -> 'Database':
        """Creates a Database from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry.

        The dict is built on first use and cached until the database is next
        modified. It is a defaultdict, so looking up a token that is not in the
        database yields an empty list.
        """
        if self._cache is None:  # build the token -> entries cache
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[Tuple[int, List[TokenizedStringEntry]]]:
        """Yields (token, entries_list) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
            self,
            all_entries: Iterable[TokenizedStringEntry],
            removal_date: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for the
        database. Entries currently in the database not present in the provided
        entries are marked with a removal date but remain in the database.
        Entries in all_entries missing from the database are NOT added; call the
        add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; the current time by
              default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if (entry.key() not in all_keys
                    and (entry.date_removed is None
                         or removal_date < entry.date_removed)):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        An entry that is already in the database has its domain updated and its
        removal date cleared. An entry that is not is added as a new
        TokenizedStringEntry without a removal date.
        """
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain
                entry.date_removed = None
            except KeyError:
                self._database[new_entry.key()] = TokenizedStringEntry(
                    new_entry.token, new_entry.string, new_entry.domain)

    def purge(
        self,
        date_removed_cutoff: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff.

        Args:
          date_removed_cutoff: entries removed on or before this datetime are
              purged; if None, every entry with a removal date is purged

        Returns:
          A list of the entries that were deleted from the database.
        """
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: 'Database') -> None:
        """Merges two or more databases together, keeping the newest dates.

        Entries that are not yet in this database are stored by reference, so
        they remain shared with the database they came from.
        """
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
        self,
        include: Iterable[Union[str, Pattern[str]]] = (),
        exclude: Iterable[Union[str, Pattern[str]]] = (),
        replace: Iterable[Tuple[Union[str, Pattern[str]], str]] = ()
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Entries are removed first; the replacements are then applied, in order,
        to the strings of the remaining entries. If replacements make several
        entries share the same (token, string), the first one is kept and takes
        the newest removal date.

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        # Compile up front so that whether any patterns were provided can be
        # checked reliably, even if the arguments are one-shot iterators.
        include_re = [re.compile(pattern) for pattern in include]
        exclude_re = [re.compile(pattern) for pattern in exclude]

        def is_filtered_out(string: str) -> bool:
            if include_re and not any(
                    rgx.search(string) for rgx in include_re):
                return True

            return any(rgx.search(string) for rgx in exclude_re)

        # Each key is collected at most once, so an entry that fails the
        # include check and also matches an exclude pattern is not deleted
        # twice (which would raise a KeyError).
        to_delete: List[_EntryKey] = [
            key for key, entry in self._database.items()
            if is_filtered_out(entry.string)
        ]

        for key in to_delete:
            del self._database[key]

        replace_re = [(re.compile(search), replacement)
                      for search, replacement in replace]

        if not replace_re:
            return

        for entry in self._database.values():
            for search, replacement in replace_re:
                entry.string = search.sub(replacement, entry.string)

        # The string is part of each entry's key, so the dict must be re-keyed
        # after the strings change. Otherwise entry.key() would no longer match
        # the dict key, which breaks functions such as purge() and add().
        rekeyed: Dict[_EntryKey, TokenizedStringEntry] = {}

        for entry in self._database.values():
            kept = rekeyed.setdefault(entry.key(), entry)
            if kept is not entry:
                kept.update_date_removed(entry.date_removed)

        # Update the dict in place so that views returned by entries() remain
        # attached to the database.
        self._database.clear()
        self._database.update(rekeyed)

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()
296
297
def parse_csv(fd) -> Iterable[TokenizedStringEntry]:
    """Yields the TokenizedStringEntries in a CSV token database file.

    Each row holds a hexadecimal token, a removal date (blank if the entry has
    not been removed), and the string. Rows that cannot be parsed are logged as
    errors and skipped.

    Args:
      fd: an iterable of CSV text lines, such as a file opened in text mode
    """
    for row in csv.reader(fd):
        try:
            token_field, date_field, text = row

            token = int(token_field, 16)

            if date_field.strip():
                removed = datetime.strptime(date_field, DATE_FORMAT)
            else:
                removed = None

            yield TokenizedStringEntry(token, text, DEFAULT_DOMAIN, removed)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error('Failed to parse tokenized string entry %s: %s', row,
                       err)
313
314
def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the sorted database entries as CSV to a binary file object.

    Each line holds the token as 8 hexadecimal digits, the removal date padded
    to 10 characters (blank if the entry has not been removed), and the quoted
    string. Lines end with \\n rather than RFC 4180's \\r\\n.
    """
    for entry in sorted(database.entries()):
        removed = entry.date_removed
        date_column = removed.strftime(DATE_FORMAT) if removed else ''

        # A double quote inside a quoted CSV field is escaped as "".
        escaped_string = entry.string.replace('"', '""')

        # Fixed-width token and date columns keep the file easy to read.
        line = '{:08x},{:10},"{}"\n'.format(entry.token, date_column,
                                            escaped_string)
        fd.write(line.encode())
324
325
class _BinaryFileFormat(NamedTuple):
    """Layout of the binary token database file format.

    All values are little-endian with no implicit padding.
    """

    # Magic string with which every binary database file starts.
    magic: bytes = b'TOKENS\0\0'

    # File header: 8-byte magic, 32-bit entry count, 4 pad bytes.
    header: struct.Struct = struct.Struct('<8sI4x')

    # One entry record: 32-bit token, then the removal date as an 8-bit day,
    # 8-bit month, and 16-bit year.
    entry: struct.Struct = struct.Struct('<IBBH')


# The single shared description of the binary format used by this module.
BINARY_FORMAT = _BinaryFileFormat()
335
336
class DatabaseFormatError(Exception):
    """Raised when a token database file cannot be parsed."""
339
340
def file_is_binary_database(fd: BinaryIO) -> bool:
    """Checks whether a file begins with the binary database magic string.

    The file is rewound before and after the check, so on success it is left
    positioned at its start.

    Returns:
      True if the magic string matches; False if it does not or if seeking or
      reading the file fails with an IOError.
    """
    expected_magic = BINARY_FORMAT.magic

    try:
        fd.seek(0)
        found_magic = fd.read(len(expected_magic))
        fd.seek(0)
    except IOError:
        return False

    return expected_magic == found_magic
350
351
352def _check_that_file_is_csv_database(path: Path) -> None:
353    """Raises an error unless the path appears to be a CSV token database."""
354    try:
355        with path.open('rb') as fd:
356            data = fd.read(8)  # Read 8 bytes, which should be the first token.
357
358        if not data:
359            return  # File is empty, which is valid CSV.
360
361        if len(data) != 8:
362            raise DatabaseFormatError(
363                f'Attempted to read {path} as a CSV token database, but the '
364                f'file is too short ({len(data)} B)')
365
366        # Make sure the first 8 chars are a valid hexadecimal number.
367        _ = int(data.decode(), 16)
368    except (IOError, UnicodeDecodeError, ValueError) as err:
369        raise DatabaseFormatError(
370            f'Encountered error while reading {path} as a CSV token database'
371        ) from err
372
373
def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file.

    The file consists of a header, one fixed-size record per entry holding the
    token and removal date, and a table of null-terminated strings stored in
    the same order as the records.

    Args:
      fd: binary file object positioned at the start of the database

    Yields:
      One TokenizedStringEntry per record, in file order.

    Raises:
      DatabaseFormatError: the magic string does not match, or the string
          table does not contain a null-terminated string for every record
    """
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size))

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}')

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size))

        # Entries without a removal date are stored with an out-of-range date,
        # which datetime() rejects with a ValueError.
        try:
            date_removed: Optional[datetime] = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # The rest of the file is the string table.
    string_table = fd.read()

    offset = 0
    for token, removed in entries:
        end = string_table.find(b'\0', offset)

        # find() returns -1 if there is no terminator. Using that as a slice
        # end would drop the string's last character and restart the table at
        # offset 0, silently pairing the remaining tokens with wrong strings.
        if end < 0:
            raise DatabaseFormatError(
                'Binary token database string table is truncated: missing '
                f'null terminator for the string of token {token:08x} at '
                f'string table offset {offset} while reading from {fd}')

        yield TokenizedStringEntry(token, string_table[offset:end].decode(),
                                   DEFAULT_DOMAIN, removed)
        offset = end + 1
409
410
def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the sorted database entries in the packed binary format.

    The header is written first, then one fixed-size record per entry, and
    finally the table of null-terminated strings in the same order.
    """
    sorted_entries = sorted(database.entries())

    header = BINARY_FORMAT.header.pack(BINARY_FORMAT.magic,
                                       len(sorted_entries))
    fd.write(header)

    # The strings are accumulated here and written after all of the records.
    strings = bytearray()

    for entry in sorted_entries:
        removed = entry.date_removed

        if removed:
            day, month, year = removed.day, removed.month, removed.year
        else:
            # Entries without a removal date get all-ones day/month/year
            # fields, so that still-present tokens appear as the newest tokens
            # when sorted by removal date.
            day, month, year = 0xff, 0xff, 0xffff

        strings += entry.string.encode()
        strings.append(0)

        fd.write(BINARY_FORMAT.entry.pack(entry.token, day, month, year))

    fd.write(strings)
440
441
class DatabaseFile(Database):
    """A token database that is backed by a particular file.

    In addition to the Database functionality, this class provides
    write_to_file(), which writes the database in the format (CSV or binary) of
    the file from which it was loaded.
    """
    def __init__(self, path: Union[Path, str]):
        """Loads the database at path, detecting whether it is binary or CSV.

        Raises:
          DatabaseFormatError: the file is neither a binary database nor
              something that appears to be a CSV database
        """
        self.path = Path(path)

        # A file starting with the magic string is a packed binary database.
        with self.path.open('rb') as binary_fd:
            if file_is_binary_database(binary_fd):
                super().__init__(parse_binary(binary_fd))
                self._export = write_binary
                return

        # Anything else must be a CSV database.
        _check_that_file_is_csv_database(self.path)

        with self.path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(parse_csv(csv_fd))

        self._export = write_csv

    def write_to_file(self, path: Optional[Union[Path, str]] = None) -> None:
        """Exports in the original format to the original or provided path."""
        destination = self.path if path is None else path

        with open(destination, 'wb') as fd:
            # _export is a plain function stored on the instance, so the
            # database has to be passed to it explicitly.
            self._export(self, fd)
468