1#!/usr/bin/env python3
2# ################################################################
3# Copyright (c) 2020-2020, Facebook, Inc.
4# All rights reserved.
5#
6# This source code is licensed under both the BSD-style license (found in the
7# LICENSE file in the root directory of this source tree) and the GPLv2 (found
8# in the COPYING file in the root directory of this source tree).
9# You may select, at your option, one of the above-listed licenses.
10# ##########################################################################
11
12import argparse
13import contextlib
14import os
15import re
16import shutil
17import sys
18from typing import Optional
19
20
# Subdirectories of the source library whose .c/.h files are copied into the
# output library.
INCLUDED_SUBDIRS = ["common", "compress", "decompress"]

# Library-relative paths that are never copied from the source library.
# mem.h and zstd_deps.h are installed separately from the files given on the
# command line (see Freestanding._copy_mem / Freestanding._copy_zstd_deps).
SKIPPED_FILES = [
    "common/mem.h",
    "common/zstd_deps.h",
    "common/pool.c",
    "common/pool.h",
    "common/threading.c",
    "common/threading.h",
    "compress/zstdmt_compress.h",
    "compress/zstdmt_compress.c",
]

# Library-relative paths that are skipped only when an external xxhash is used.
XXHASH_FILES = [
    "common/xxhash.c",
    "common/xxhash.h",
]
38
39
class FileLines(object):
    """
    In-memory copy of a text file's lines that can be written back in place.

    `filename` is the path that was read; `lines` is the list returned by
    readlines() (line terminators included) and may be edited or replaced
    before calling write().
    """
    def __init__(self, filename):
        self.filename = filename
        with open(filename, "r") as handle:
            self.lines = handle.readlines()

    def write(self):
        """Overwrite the file with the current contents of `lines`."""
        with open(self.filename, "w") as handle:
            handle.writelines(self.lines)
49
50
51class PartialPreprocessor(object):
52    """
53    Looks for simple ifdefs and ifndefs and replaces them.
54    Handles && and ||.
55    Has fancy logic to handle translating elifs to ifs.
56    Only looks for macros in the first part of the expression with no
57    parens.
58    Does not handle multi-line macros (only looks in first line).
59    """
60    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
61        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
62        ELIF_GROUP = r"(?P<elif>el)?"
63        OP_GROUP = r"(?P<op>&&|\|\|)?"
64
65        self._defs = {macro:value for macro, value in defs}
66        self._replaces = {macro:value for macro, value in replaces}
67        self._defs.update(self._replaces)
68        self._undefs = set(undefs)
69
70        self._define = re.compile(r"\s*#\s*define")
71        self._if = re.compile(r"\s*#\s*if")
72        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
73        self._else = re.compile(r"\s*#\s*(?P<else>else)")
74        self._endif = re.compile(r"\s*#\s*endif")
75
76        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
77        self._if_defined = re.compile(
78            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
79        )
80        self._if_defined_value = re.compile(
81            fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
82            fr"(?P<op>&&)\s*"
83            fr"(?P<openp>\()?\s*"
84            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
85            fr"(?P<cmp>[=><!]+)\s*"
86            fr"(?P<value>[0-9]*)\s*"
87            fr"(?P<closep>\))?\s*"
88        )
89        self._if_true = re.compile(
90            fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}"
91        )
92
93        self._c_comment = re.compile(r"/\*.*?\*/")
94        self._cpp_comment = re.compile(r"//")
95
    def _log(self, *args, **kwargs):
        """Log by forwarding all arguments unchanged to print()."""
        print(*args, **kwargs)
98
99    def _strip_comments(self, line):
100        # First strip c-style comments (may include //)
101        while True:
102            m = self._c_comment.search(line)
103            if m is None:
104                break
105            line = line[:m.start()] + line[m.end():]
106
107        # Then strip cpp-style comments
108        m = self._cpp_comment.search(line)
109        if m is not None:
110            line = line[:m.start()]
111
112        return line
113
114    def _fixup_indentation(self, macro, replace: [str]):
115        if len(replace) == 0:
116            return replace
117        if len(replace) == 1 and self._define.match(replace[0]) is None:
118            # If there is only one line, only replace defines
119            return replace
120
121
122        all_pound = True
123        for line in replace:
124            if not line.startswith('#'):
125                all_pound = False
126        if all_pound:
127            replace = [line[1:] for line in replace]
128
129        min_spaces = len(replace[0])
130        for line in replace:
131            spaces = 0
132            for i, c in enumerate(line):
133                if c != ' ':
134                    # Non-preprocessor line ==> skip the fixup
135                    if not all_pound and c != '#':
136                        return replace
137                    spaces = i
138                    break
139            min_spaces = min(min_spaces, spaces)
140
141        replace = [line[min_spaces:] for line in replace]
142
143        if all_pound:
144            replace = ["#" + line for line in replace]
145
146        return replace
147
148    def _handle_if_block(self, macro, idx, is_true, prepend):
149        """
150        Remove the #if or #elif block starting on this line.
151        """
152        REMOVE_ONE = 0
153        KEEP_ONE = 1
154        REMOVE_REST = 2
155
156        if is_true:
157            state = KEEP_ONE
158        else:
159            state = REMOVE_ONE
160
161        line = self._inlines[idx]
162        is_if = self._if.match(line) is not None
163        assert is_if or self._elif.match(line) is not None
164        depth = 0
165
166        start_idx = idx
167
168        idx += 1
169        replace = prepend
170        finished = False
171        while idx < len(self._inlines):
172            line = self._inlines[idx]
173            # Nested if statement
174            if self._if.match(line):
175                depth += 1
176                idx += 1
177                continue
178            # We're inside a nested statement
179            if depth > 0:
180                if self._endif.match(line):
181                    depth -= 1
182                idx += 1
183                continue
184
185            # We're at the original depth
186
187            # Looking only for an endif.
188            # We've found a true statement, but haven't
189            # completely elided the if block, so we just
190            # remove the remainder.
191            if state == REMOVE_REST:
192                if self._endif.match(line):
193                    if is_if:
194                        # Remove the endif because we took the first if
195                        idx += 1
196                    finished = True
197                    break
198                idx += 1
199                continue
200
201            if state == KEEP_ONE:
202                m = self._elif.match(line)
203                if self._endif.match(line):
204                    replace += self._inlines[start_idx + 1:idx]
205                    idx += 1
206                    finished = True
207                    break
208                if self._elif.match(line) or self._else.match(line):
209                    replace += self._inlines[start_idx + 1:idx]
210                    state = REMOVE_REST
211                idx += 1
212                continue
213
214            if state == REMOVE_ONE:
215                m = self._elif.match(line)
216                if m is not None:
217                    if is_if:
218                        idx += 1
219                        b = m.start('elif')
220                        e = m.end('elif')
221                        assert e - b == 2
222                        replace.append(line[:b] + line[e:])
223                    finished = True
224                    break
225                m = self._else.match(line)
226                if m is not None:
227                    if is_if:
228                        idx += 1
229                        while self._endif.match(self._inlines[idx]) is None:
230                            replace.append(self._inlines[idx])
231                            idx += 1
232                        idx += 1
233                    finished = True
234                    break
235                if self._endif.match(line):
236                    if is_if:
237                        # Remove the endif because no other elifs
238                        idx += 1
239                    finished = True
240                    break
241                idx += 1
242                continue
243        if not finished:
244            raise RuntimeError("Unterminated if block!")
245
246        replace = self._fixup_indentation(macro, replace)
247
248        self._log(f"\tHardwiring {macro}")
249        if start_idx > 0:
250            self._log(f"\t\t  {self._inlines[start_idx - 1][:-1]}")
251        for x in range(start_idx, idx):
252            self._log(f"\t\t- {self._inlines[x][:-1]}")
253        for line in replace:
254            self._log(f"\t\t+ {line[:-1]}")
255        if idx < len(self._inlines):
256            self._log(f"\t\t  {self._inlines[idx][:-1]}")
257
258        return idx, replace
259
260    def _preprocess_once(self):
261        outlines = []
262        idx = 0
263        changed = False
264        while idx < len(self._inlines):
265            line = self._inlines[idx]
266            sline = self._strip_comments(line)
267            m = self._ifdef.fullmatch(sline)
268            if_true = False
269            if m is None:
270                m = self._if_defined_value.fullmatch(sline)
271            if m is None:
272                m = self._if_defined.match(sline)
273            if m is None:
274                m = self._if_true.match(sline)
275                if_true = (m is not None)
276            if m is None:
277                outlines.append(line)
278                idx += 1
279                continue
280
281            groups = m.groupdict()
282            macro = groups['macro']
283            op = groups.get('op')
284
285            if not (macro in self._defs or macro in self._undefs):
286                outlines.append(line)
287                idx += 1
288                continue
289
290            defined = macro in self._defs
291
292            # Needed variables set:
293            # resolved: Is the statement fully resolved?
294            # is_true: If resolved, is the statement true?
295            ifdef = False
296            if if_true:
297                if not defined:
298                    outlines.append(line)
299                    idx += 1
300                    continue
301
302                defined_value = self._defs[macro]
303                is_int = True
304                try:
305                    defined_value = int(defined_value)
306                except TypeError:
307                    is_int = False
308                except ValueError:
309                    is_int = False
310
311                resolved = is_int
312                is_true = (defined_value != 0)
313
314                if resolved and op is not None:
315                    if op == '&&':
316                        resolved = not is_true
317                    else:
318                        assert op == '||'
319                        resolved = is_true
320
321            else:
322                ifdef = groups.get('not') is None
323                elseif = groups.get('elif') is not None
324
325                macro2 = groups.get('macro2')
326                cmp = groups.get('cmp')
327                value = groups.get('value')
328                openp = groups.get('openp')
329                closep = groups.get('closep')
330
331                is_true = (ifdef == defined)
332                resolved = True
333                if op is not None:
334                    if op == '&&':
335                        resolved = not is_true
336                    else:
337                        assert op == '||'
338                        resolved = is_true
339
340                if macro2 is not None and not resolved:
341                    assert ifdef and defined and op == '&&' and cmp is not None
342                    # If the statment is true, but we have a single value check, then
343                    # check the value.
344                    defined_value = self._defs[macro]
345                    are_ints = True
346                    try:
347                        defined_value = int(defined_value)
348                        value = int(value)
349                    except TypeError:
350                        are_ints = False
351                    except ValueError:
352                        are_ints = False
353                    if (
354                            macro == macro2 and
355                            ((openp is None) == (closep is None)) and
356                            are_ints
357                    ):
358                        resolved = True
359                        if cmp == '<':
360                            is_true = defined_value < value
361                        elif cmp == '<=':
362                            is_true = defined_value <= value
363                        elif cmp == '==':
364                            is_true = defined_value == value
365                        elif cmp == '!=':
366                            is_true = defined_value != value
367                        elif cmp == '>=':
368                            is_true = defined_value >= value
369                        elif cmp == '>':
370                            is_true = defined_value > value
371                        else:
372                            resolved = False
373
374                if op is not None and not resolved:
375                    # Remove the first op in the line + spaces
376                    if op == '&&':
377                        opre = op
378                    else:
379                        assert op == '||'
380                        opre = r'\|\|'
381                    needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
382                    match = needle.match(line)
383                    assert match is not None
384                    newline = line[:match.end('if')] + line[match.end('op'):]
385
386                    self._log(f"\tHardwiring partially resolved {macro}")
387                    self._log(f"\t\t- {line[:-1]}")
388                    self._log(f"\t\t+ {newline[:-1]}")
389
390                    outlines.append(newline)
391                    idx += 1
392                    continue
393
394            # Skip any statements we cannot fully compute
395            if not resolved:
396                outlines.append(line)
397                idx += 1
398                continue
399
400            prepend = []
401            if macro in self._replaces:
402                assert not ifdef
403                assert op is None
404                value = self._replaces.pop(macro)
405                prepend = [f"#define {macro} {value}\n"]
406
407            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
408            outlines += replace
409            changed = True
410
411        return changed, outlines
412
413    def preprocess(self, filename):
414        with open(filename, 'r') as f:
415            self._inlines = f.readlines()
416        changed = True
417        iters = 0
418        while changed:
419            iters += 1
420            changed, outlines = self._preprocess_once()
421            self._inlines = outlines
422
423        with open(filename, 'w') as f:
424            f.write(''.join(self._inlines))
425
426
427class Freestanding(object):
428    def __init__(
429            self, zstd_deps: str, mem: str, source_lib: str, output_lib: str,
430            external_xxhash: bool, xxh64_state: Optional[str],
431            xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
432            defs: [(str, Optional[str])], replaces: [(str, str)],
433            undefs: [str], excludes: [str]
434    ):
435        self._zstd_deps = zstd_deps
436        self._mem = mem
437        self._src_lib = source_lib
438        self._dst_lib = output_lib
439        self._external_xxhash = external_xxhash
440        self._xxh64_state = xxh64_state
441        self._xxh64_prefix = xxh64_prefix
442        self._rewritten_includes = rewritten_includes
443        self._defs = defs
444        self._replaces = replaces
445        self._undefs = undefs
446        self._excludes = excludes
447
448    def _dst_lib_file_paths(self):
449        """
450        Yields all the file paths in the dst_lib.
451        """
452        for root, dirname, filenames in os.walk(self._dst_lib):
453            for filename in filenames:
454                filepath = os.path.join(root, filename)
455                yield filepath
456
    def _log(self, *args, **kwargs):
        """Log by forwarding all arguments unchanged to print()."""
        print(*args, **kwargs)
459
460    def _copy_file(self, lib_path):
461        if not (lib_path.endswith(".c") or lib_path.endswith(".h")):
462            return
463        if lib_path in SKIPPED_FILES:
464            self._log(f"\tSkipping file: {lib_path}")
465            return
466        if self._external_xxhash and lib_path in XXHASH_FILES:
467            self._log(f"\tSkipping xxhash file: {lib_path}")
468            return
469
470        src_path = os.path.join(self._src_lib, lib_path)
471        dst_path = os.path.join(self._dst_lib, lib_path)
472        self._log(f"\tCopying: {src_path} -> {dst_path}")
473        shutil.copyfile(src_path, dst_path)
474
475    def _copy_source_lib(self):
476        self._log("Copying source library into output library")
477
478        assert os.path.exists(self._src_lib)
479        os.makedirs(self._dst_lib, exist_ok=True)
480        self._copy_file("zstd.h")
481        for subdir in INCLUDED_SUBDIRS:
482            src_dir = os.path.join(self._src_lib, subdir)
483            dst_dir = os.path.join(self._dst_lib, subdir)
484
485            assert os.path.exists(src_dir)
486            os.makedirs(dst_dir, exist_ok=True)
487
488            for filename in os.listdir(src_dir):
489                lib_path = os.path.join(subdir, filename)
490                self._copy_file(lib_path)
491
492    def _copy_zstd_deps(self):
493        dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h")
494        self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}")
495        shutil.copyfile(self._zstd_deps, dst_zstd_deps)
496
497    def _copy_mem(self):
498        dst_mem = os.path.join(self._dst_lib, "common", "mem.h")
499        self._log(f"Copying mem: {self._mem} -> {dst_mem}")
500        shutil.copyfile(self._mem, dst_mem)
501
    def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False):
        """
        If value=None then hardwire that it is defined, but not what the value is.
        If undef=True then value must be None.
        If value='' then the macro is defined to '' exactly.
        """
        # Reject the contradictory request "undefine, but with a value".
        assert not (undef and value is not None)
        # NOTE(review): each output file is loaded into a FileLines object but
        # nothing is modified or written back, so as written this method has
        # no effect on the output library. It looks like an unfinished stub;
        # confirm whether it is still needed.
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
511
512    def _hardwire_defines(self):
513        self._log("Hardwiring macros")
514        partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
515        for filepath in self._dst_lib_file_paths():
516            partial_preprocessor.preprocess(filepath)
517
518    def _remove_excludes(self):
519        self._log("Removing excluded sections")
520        for exclude in self._excludes:
521            self._log(f"\tRemoving excluded sections for: {exclude}")
522            begin_re = re.compile(f"BEGIN {exclude}")
523            end_re = re.compile(f"END {exclude}")
524            for filepath in self._dst_lib_file_paths():
525                file = FileLines(filepath)
526                outlines = []
527                skipped = []
528                emit = True
529                for line in file.lines:
530                    if emit and begin_re.search(line) is not None:
531                        assert end_re.search(line) is None
532                        emit = False
533                    if emit:
534                        outlines.append(line)
535                    else:
536                        skipped.append(line)
537                        if end_re.search(line) is not None:
538                            assert begin_re.search(line) is None
539                            self._log(f"\t\tRemoving excluded section: {exclude}")
540                            for s in skipped:
541                                self._log(f"\t\t\t- {s}")
542                            emit = True
543                            skipped = []
544                if not emit:
545                    raise RuntimeError("Excluded section unfinished!")
546                file.lines = outlines
547                file.write()
548
549    def _rewrite_include(self, original, rewritten):
550        self._log(f"\tRewriting include: {original} -> {rewritten}")
551        regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})")
552        for filepath in self._dst_lib_file_paths():
553            file = FileLines(filepath)
554            for i, line in enumerate(file.lines):
555                match = regex.match(line)
556                if match is None:
557                    continue
558                s = match.start('include')
559                e = match.end('include')
560                file.lines[i] = line[:s] + rewritten + line[e:]
561            file.write()
562
563    def _rewrite_includes(self):
564        self._log("Rewriting includes")
565        for original, rewritten in self._rewritten_includes:
566            self._rewrite_include(original, rewritten)
567
568    def _replace_xxh64_prefix(self):
569        if self._xxh64_prefix is None:
570            return
571        self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
572        replacements = []
573        if self._xxh64_state is not None:
574            replacements.append(
575                (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state)
576            )
577        if self._xxh64_prefix is not None:
578            replacements.append(
579                (re.compile(r"([^\w]|^)(?P<orig>XXH64)_"), self._xxh64_prefix)
580            )
581        for filepath in self._dst_lib_file_paths():
582            file = FileLines(filepath)
583            for i, line in enumerate(file.lines):
584                modified = False
585                for regex, replacement in replacements:
586                    match = regex.search(line)
587                    while match is not None:
588                        modified = True
589                        b = match.start('orig')
590                        e = match.end('orig')
591                        line = line[:b] + replacement + line[e:]
592                        match = regex.search(line)
593                if modified:
594                    self._log(f"\t- {file.lines[i][:-1]}")
595                    self._log(f"\t+ {line[:-1]}")
596                file.lines[i] = line
597            file.write()
598
599    def go(self):
600        self._copy_source_lib()
601        self._copy_zstd_deps()
602        self._copy_mem()
603        self._hardwire_defines()
604        self._remove_excludes()
605        self._rewrite_includes()
606        self._replace_xxh64_prefix()
607
608
def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
    """
    Parse `NAME` / `NAME=VALUE` strings into (name, value) pairs.

    A string without '=' yields (NAME, None). Otherwise the string is split
    at the first '=' only, so the value may itself contain '=' (as with a
    compiler's -D option); `NAME=` yields (NAME, '').
    """
    output = []
    for define in defines:
        name, sep, value = define.partition('=')
        output.append((name, value if sep else None))
    return output
620
621
def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
    """
    Parse `LEFT=RIGHT` strings into (left, right) pairs.

    Raises RuntimeError for any string that does not contain exactly one '='.
    """
    pairs = []
    for item in rewritten_includes:
        pieces = item.split('=')
        if len(pieces) != 2:
            raise RuntimeError(f"Bad rewritten include: {item}")
        pairs.append(tuple(pieces))
    return pairs
631
632
633
def main(name, args):
    """
    Command-line entry point: parse `args` and build the freestanding library.

    name: program name shown in usage messages.
    args: command-line arguments, without the program name.

    ZSTD_MULTITHREAD is always added to the undefined macros.

    Raises RuntimeError when a macro is both defined and undefined, when a
    replaced macro is also defined or undefined, when a -D/-R/--rewrite-include
    value is malformed, or when --xxh64-prefix / --xxh64-state is given
    without --xxhash.
    """
    parser = argparse.ArgumentParser(prog=name)
    parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file")
    parser.add_argument("--mem", default="mem.h", help="Memory module")
    parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
    parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
    parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
    parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
    parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
    parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
    parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
    parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed multiple times)")
    parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
    parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
    args = parser.parse_args(args)

    # Always remove threading
    if "ZSTD_MULTITHREAD" not in args.undefs:
        args.undefs.append("ZSTD_MULTITHREAD")

    args.defs = parse_optional_pair(args.defs)
    for macro, _ in args.defs:
        if macro in args.undefs:
            raise RuntimeError(f"{macro} is both defined and undefined!")

    # args.defs holds (name, value) tuples, so conflicts with replaced macros
    # have to be checked against the names alone.
    defined_names = {macro for macro, _ in args.defs}

    args.replaces = parse_pair(args.replaces)
    for macro, _ in args.replaces:
        if macro in args.undefs or macro in defined_names:
            raise RuntimeError(f"{macro} is both replaced and (un)defined!")

    args.rewritten_includes = parse_pair(args.rewritten_includes)

    external_xxhash = False
    if args.xxhash is not None:
        external_xxhash = True
        # Point both spellings of the bundled xxhash include at the external one.
        args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))

    if args.xxh64_prefix is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")

    if args.xxh64_state is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-state may only be used with --xxhash provided")

    Freestanding(
        args.zstd_deps,
        args.mem,
        args.source_lib,
        args.output_lib,
        external_xxhash,
        args.xxh64_state,
        args.xxh64_prefix,
        args.rewritten_includes,
        args.defs,
        args.replaces,
        args.undefs,
        args.excludes
    ).go()
693
# Script entry point: pass the program name and the remaining command-line
# arguments to main().
if __name__ == "__main__":
    main(sys.argv[0], sys.argv[1:])
696