#!/usr/bin/env python3
# ################################################################
# Copyright (c) 2020-2020, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under both the BSD-style license (found in the
# LICENSE file in the root directory of this source tree) and the GPLv2 (found
# in the COPYING file in the root directory of this source tree).
# You may select, at your option, one of the above-listed licenses.
# ##########################################################################

"""
Produce a "freestanding" copy of the zstd library.

The selected library sources are copied into an output directory and then
rewritten in place: chosen preprocessor macros are hardwired, excluded
sections are removed, includes are rewritten and XXH64 symbols can be renamed
for use with an external xxhash implementation.
"""

import argparse
import contextlib
import os
import re
import shutil
import sys
from typing import Optional


INCLUDED_SUBDIRS = ["common", "compress", "decompress"]

SKIPPED_FILES = [
    "common/mem.h",
    "common/zstd_deps.h",
    "common/pool.c",
    "common/pool.h",
    "common/threading.c",
    "common/threading.h",
    "compress/zstdmt_compress.h",
    "compress/zstdmt_compress.c",
]

XXHASH_FILES = [
    "common/xxhash.c",
    "common/xxhash.h",
]


class FileLines(object):
    """Load a file as a list of lines that can be edited and written back."""

    def __init__(self, filename):
        self.filename = filename
        with open(self.filename, "r") as f:
            self.lines = f.readlines()

    def write(self):
        """Overwrite the file with the current contents of `self.lines`."""
        with open(self.filename, "w") as f:
            f.write("".join(self.lines))


class PartialPreprocessor(object):
    """
    Looks for simple ifdefs and ifndefs and replaces them.
    Handles && and ||.
    Has fancy logic to handle translating elifs to ifs.
    Only looks for macros in the first part of the expression with no
    parens.
    Does not handle multi-line macros (only looks in first line).

    Conditions that cannot be evaluated with certainty (comparisons after the
    leading macro, `A && B || C` with a false `A`, ...) are left untouched.
    """
    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
        ELIF_GROUP = r"(?P<elif>el)?"
        OP_GROUP = r"(?P<op>&&|\|\|)?"

        self._defs = {macro: value for macro, value in defs}
        self._replaces = {macro: value for macro, value in replaces}
        # Replaced macros also count as defined.
        self._defs.update(self._replaces)
        self._undefs = set(undefs)

        self._define = re.compile(r"\s*#\s*define")
        self._if = re.compile(r"\s*#\s*if")
        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
        self._else = re.compile(r"\s*#\s*(?P<else>else)")
        self._endif = re.compile(r"\s*#\s*endif")

        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
        self._if_defined = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
        )
        self._if_defined_value = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
            fr"(?P<op>&&)\s*"
            fr"(?P<openp>\()?\s*"
            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
            fr"(?P<cmp>[=><!]+)\s*"
            fr"(?P<value>[0-9]*)\s*"
            fr"(?P<closep>\))?\s*"
        )
        self._if_true = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}"
        )

        self._c_comment = re.compile(r"/\*.*?\*/")
        self._cpp_comment = re.compile(r"//")

    def _log(self, *args, **kwargs):
        print(*args, **kwargs)

    @staticmethod
    def _has_top_level_or(expr):
        """
        Return True if `expr` contains `||` outside of any parentheses.

        An unbalanced `)` is treated conservatively as "top level".
        """
        depth = 0
        for i, c in enumerate(expr):
            if c == '(':
                depth += 1
            elif c == ')':
                depth -= 1
            elif depth <= 0 and expr.startswith('||', i):
                return True
        return False

    def _strip_comments(self, line):
        """Remove single-line C comments and any trailing // comment."""
        # First strip c-style comments (may include //)
        while True:
            m = self._c_comment.search(line)
            if m is None:
                break
            line = line[:m.start()] + line[m.end():]

        # Then strip cpp-style comments
        m = self._cpp_comment.search(line)
        if m is not None:
            line = line[:m.start()]

        return line

    def _fixup_indentation(self, macro, replace: [str]):
        """
        Remove the indentation shared by all lines of `replace`.

        Applies only when every line is a preprocessor line, either in
        `#  define` form (every line starts with '#', spaces after it) or in
        `  #define` form (spaces before '#'). A single line is only adjusted
        when it is a #define. Any other content returns `replace` unchanged.
        """
        if len(replace) == 0:
            return replace
        if len(replace) == 1 and self._define.match(replace[0]) is None:
            # If there is only one line, only replace defines
            return replace

        all_pound = all(line.startswith('#') for line in replace)
        if all_pound:
            replace = [line[1:] for line in replace]

        min_spaces = len(replace[0])
        for line in replace:
            spaces = 0
            for i, c in enumerate(line):
                if c != ' ':
                    # Non-preprocessor line ==> skip the fixup
                    if not all_pound and c != '#':
                        return replace
                    spaces = i
                    break
            min_spaces = min(min_spaces, spaces)

        replace = [line[min_spaces:] for line in replace]

        if all_pound:
            replace = ["#" + line for line in replace]

        return replace

    def _handle_if_block(self, macro, idx, is_true, prepend):
        """
        Remove the #if or #elif block starting on this line.

        `idx` indexes the #if/#elif line in self._inlines and `is_true` is the
        resolved value of its condition. `prepend` lines are emitted first.

        For an #if, the directive and its matching #endif disappear and only
        the selected branch body is kept; a following #elif becomes the new
        #if. For an #elif, the enclosing #if/#endif stay in place: a false
        #elif is dropped and a true #elif is rewritten as #else.

        Returns (index of the first unconsumed line, replacement lines).
        Raises RuntimeError if no matching #endif is found.
        """
        REMOVE_ONE = 0
        KEEP_ONE = 1
        REMOVE_REST = 2

        state = KEEP_ONE if is_true else REMOVE_ONE

        line = self._inlines[idx]
        is_if = self._if.match(line) is not None
        elif_match = self._elif.match(line)
        assert is_if or elif_match is not None
        depth = 0

        start_idx = idx

        idx += 1
        replace = []
        if is_true and not is_if:
            # A true #elif ends the preceding branch; turn it into an #else so
            # the enclosing #if/#endif pair stays balanced.
            replace.append(line[:elif_match.start('elif')] + "else\n")
        replace += prepend
        finished = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            # Nested if statement
            if self._if.match(line):
                depth += 1
                idx += 1
                continue
            # We're inside a nested statement
            if depth > 0:
                if self._endif.match(line):
                    depth -= 1
                idx += 1
                continue

            # We're at the original depth

            # Looking only for an endif.
            # We've found a true statement, but haven't
            # completely elided the if block, so we just
            # remove the remainder.
            if state == REMOVE_REST:
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because we took the first if
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue

            if state == KEEP_ONE:
                if self._endif.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    if is_if:
                        # The #if is gone, so its #endif goes too; an #elif
                        # still needs the enclosing #endif.
                        idx += 1
                    finished = True
                    break
                if self._elif.match(line) or self._else.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    state = REMOVE_REST
                idx += 1
                continue

            if state == REMOVE_ONE:
                m = self._elif.match(line)
                if m is not None:
                    if is_if:
                        # False #if followed by #elif: the #elif becomes the #if.
                        idx += 1
                        b = m.start('elif')
                        e = m.end('elif')
                        assert e - b == 2
                        replace.append(line[:b] + line[e:])
                    finished = True
                    break
                m = self._else.match(line)
                if m is not None:
                    if is_if:
                        # Keep the #else body, dropping the #else and its
                        # #endif. Track nesting so an inner #endif is not
                        # mistaken for the one closing this block.
                        idx += 1
                        nested = 0
                        while idx < len(self._inlines):
                            inner = self._inlines[idx]
                            if self._if.match(inner):
                                nested += 1
                            elif self._endif.match(inner):
                                if nested == 0:
                                    break
                                nested -= 1
                            replace.append(inner)
                            idx += 1
                        else:
                            raise RuntimeError("Unterminated if block!")
                        idx += 1
                    finished = True
                    break
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because no other elifs
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue
        if not finished:
            raise RuntimeError("Unterminated if block!")

        if is_if:
            # Only removing an #if drops a nesting level; an #elif rewrite
            # keeps the body at the same depth.
            replace = self._fixup_indentation(macro, replace)

        self._log(f"\tHardwiring {macro}")
        if start_idx > 0:
            self._log(f"\t\t {self._inlines[start_idx - 1][:-1]}")
        for x in range(start_idx, idx):
            self._log(f"\t\t- {self._inlines[x][:-1]}")
        for line in replace:
            self._log(f"\t\t+ {line[:-1]}")
        if idx < len(self._inlines):
            self._log(f"\t\t {self._inlines[idx][:-1]}")

        return idx, replace

    def _preprocess_once(self):
        """
        Run one pass over self._inlines.

        Returns (changed, outlines). `changed` is True if at least one
        conditional block was fully resolved in this pass.
        """
        outlines = []
        idx = 0
        changed = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            sline = self._strip_comments(line)
            m = self._ifdef.fullmatch(sline)
            if_true = False
            if m is None:
                m = self._if_defined_value.fullmatch(sline)
            if m is None:
                m = self._if_defined.match(sline)
            if m is None:
                m = self._if_true.match(sline)
                if_true = (m is not None)
            if m is None:
                outlines.append(line)
                idx += 1
                continue

            groups = m.groupdict()
            macro = groups['macro']
            op = groups.get('op')
            # Text after the matched prefix: the right operand of `op`, or any
            # trailing tokens when there is no operator.
            rest = sline[m.end():]

            if not (macro in self._defs or macro in self._undefs):
                outlines.append(line)
                idx += 1
                continue

            if op is None and rest.strip():
                # e.g. `#if FOO >= 2`: only the leading macro was matched, the
                # remainder of the expression cannot be evaluated here.
                outlines.append(line)
                idx += 1
                continue

            defined = macro in self._defs

            # Needed variables set:
            #   resolved: Is the statement fully resolved?
            #   is_true: If resolved, is the statement true?
            ifdef = False
            if if_true:
                if not defined:
                    outlines.append(line)
                    idx += 1
                    continue

                defined_value = self._defs[macro]
                is_int = True
                try:
                    defined_value = int(defined_value)
                except TypeError:
                    is_int = False
                except ValueError:
                    is_int = False

                resolved = is_int
                is_true = (defined_value != 0)

                if resolved and op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true

            else:
                ifdef = groups.get('not') is None

                macro2 = groups.get('macro2')
                cmp = groups.get('cmp')
                value = groups.get('value')
                openp = groups.get('openp')
                closep = groups.get('closep')

                is_true = (ifdef == defined)
                resolved = True
                if op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true

                if macro2 is not None and not resolved:
                    assert ifdef and defined and op == '&&' and cmp is not None
                    # If the statement is true, but we have a single value check, then
                    # check the value.
                    defined_value = self._defs[macro]
                    are_ints = True
                    try:
                        defined_value = int(defined_value)
                        value = int(value)
                    except TypeError:
                        are_ints = False
                    except ValueError:
                        are_ints = False
                    if (
                        macro == macro2 and
                        ((openp is None) == (closep is None)) and
                        are_ints
                    ):
                        resolved = True
                        if cmp == '<':
                            is_true = defined_value < value
                        elif cmp == '<=':
                            is_true = defined_value <= value
                        elif cmp == '==':
                            is_true = defined_value == value
                        elif cmp == '!=':
                            is_true = defined_value != value
                        elif cmp == '>=':
                            is_true = defined_value >= value
                        elif cmp == '>':
                            is_true = defined_value > value
                        else:
                            resolved = False

            if resolved and not is_true and op == '&&' and self._has_top_level_or(rest):
                # `A && B || C` with A false is `C` (&& binds tighter), not
                # false; leave such lines alone.
                outlines.append(line)
                idx += 1
                continue

            if op is not None and not resolved:
                # Remove the first op in the line + spaces
                if op == '&&':
                    opre = op
                else:
                    assert op == '||'
                    opre = r'\|\|'
                needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
                match = needle.match(line)
                assert match is not None
                newline = line[:match.end('if')] + line[match.end('op'):]

                self._log(f"\tHardwiring partially resolved {macro}")
                self._log(f"\t\t- {line[:-1]}")
                self._log(f"\t\t+ {newline[:-1]}")

                outlines.append(newline)
                idx += 1
                continue

            # Skip any statements we cannot fully compute
            if not resolved:
                outlines.append(line)
                idx += 1
                continue

            prepend = []
            if macro in self._replaces:
                assert not ifdef
                assert op is None
                # Each replacement is emitted once (the first block found).
                value = self._replaces.pop(macro)
                prepend = [f"#define {macro} {value}\n"]

            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
            outlines += replace
            changed = True

        return changed, outlines

    def preprocess(self, filename):
        """Rewrite `filename` in place until no further block can be resolved."""
        with open(filename, 'r') as f:
            self._inlines = f.readlines()
        changed = True
        while changed:
            changed, outlines = self._preprocess_once()
            self._inlines = outlines

        with open(filename, 'w') as f:
            f.write(''.join(self._inlines))


class Freestanding(object):
    """
    Copies the zstd library into `output_lib` and rewrites the copy in place.
    See `go` for the sequence of steps.
    """
    def __init__(
        self, zstd_deps: str, mem: str, source_lib: str, output_lib: str,
        external_xxhash: bool, xxh64_state: Optional[str],
        xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
        defs: [(str, Optional[str])], replaces: [(str, str)],
        undefs: [str], excludes: [str]
    ):
        self._zstd_deps = zstd_deps
        self._mem = mem
        self._src_lib = source_lib
        self._dst_lib = output_lib
        self._external_xxhash = external_xxhash
        self._xxh64_state = xxh64_state
        self._xxh64_prefix = xxh64_prefix
        self._rewritten_includes = rewritten_includes
        self._defs = defs
        self._replaces = replaces
        self._undefs = undefs
        self._excludes = excludes

    def _dst_lib_file_paths(self):
        """
        Yields all the file paths in the dst_lib.
        """
        for root, _dirnames, filenames in os.walk(self._dst_lib):
            for filename in filenames:
                yield os.path.join(root, filename)

    def _log(self, *args, **kwargs):
        print(*args, **kwargs)

    def _copy_file(self, lib_path):
        """Copy one .c/.h file (path relative to the library root), honouring skip lists."""
        if not (lib_path.endswith(".c") or lib_path.endswith(".h")):
            return
        if lib_path in SKIPPED_FILES:
            self._log(f"\tSkipping file: {lib_path}")
            return
        if self._external_xxhash and lib_path in XXHASH_FILES:
            self._log(f"\tSkipping xxhash file: {lib_path}")
            return

        src_path = os.path.join(self._src_lib, lib_path)
        dst_path = os.path.join(self._dst_lib, lib_path)
        self._log(f"\tCopying: {src_path} -> {dst_path}")
        shutil.copyfile(src_path, dst_path)

    def _copy_source_lib(self):
        """Copy zstd.h and the INCLUDED_SUBDIRS sources into the output directory."""
        self._log("Copying source library into output library")

        assert os.path.exists(self._src_lib)
        os.makedirs(self._dst_lib, exist_ok=True)
        self._copy_file("zstd.h")
        for subdir in INCLUDED_SUBDIRS:
            src_dir = os.path.join(self._src_lib, subdir)
            dst_dir = os.path.join(self._dst_lib, subdir)

            assert os.path.exists(src_dir)
            os.makedirs(dst_dir, exist_ok=True)

            for filename in os.listdir(src_dir):
                lib_path = os.path.join(subdir, filename)
                self._copy_file(lib_path)

    def _copy_zstd_deps(self):
        """Install the user-supplied zstd_deps.h as common/zstd_deps.h."""
        dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h")
        self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}")
        shutil.copyfile(self._zstd_deps, dst_zstd_deps)

    def _copy_mem(self):
        """Install the user-supplied mem.h as common/mem.h."""
        dst_mem = os.path.join(self._dst_lib, "common", "mem.h")
        self._log(f"Copying mem: {self._mem} -> {dst_mem}")
        shutil.copyfile(self._mem, dst_mem)

    def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False):
        """
        If value=None then hardwire that it is defined, but not what the value is.
        If undef=True then value must be None.
        If value='' then the macro is defined to '' exactly.

        NOTE(review): not called anywhere in this file and only reads each
        file without changing it; `_hardwire_defines` does the real work.
        """
        assert not (undef and value is not None)
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)

    def _hardwire_defines(self):
        """Run the partial preprocessor over every output file."""
        self._log("Hardwiring macros")
        partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
        for filepath in self._dst_lib_file_paths():
            partial_preprocessor.preprocess(filepath)

    def _remove_excludes(self):
        """
        Drop every line from a 'BEGIN <exclude>' line through the matching
        'END <exclude>' line (inclusive). Raises RuntimeError if a section
        is never closed.
        """
        self._log("Removing excluded sections")
        for exclude in self._excludes:
            self._log(f"\tRemoving excluded sections for: {exclude}")
            begin_re = re.compile(f"BEGIN {exclude}")
            end_re = re.compile(f"END {exclude}")
            for filepath in self._dst_lib_file_paths():
                file = FileLines(filepath)
                outlines = []
                skipped = []
                emit = True
                for line in file.lines:
                    if emit and begin_re.search(line) is not None:
                        assert end_re.search(line) is None
                        emit = False
                    if emit:
                        outlines.append(line)
                    else:
                        skipped.append(line)
                        if end_re.search(line) is not None:
                            assert begin_re.search(line) is None
                            self._log(f"\t\tRemoving excluded section: {exclude}")
                            for s in skipped:
                                self._log(f"\t\t\t- {s}")
                            emit = True
                            skipped = []
                if not emit:
                    raise RuntimeError("Excluded section unfinished!")
                file.lines = outlines
                file.write()

    def _rewrite_include(self, original, rewritten):
        """Replace the target of `#include <original>` (a regex) with `rewritten`."""
        self._log(f"\tRewriting include: {original} -> {rewritten}")
        regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})")
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            for i, line in enumerate(file.lines):
                match = regex.match(line)
                if match is None:
                    continue
                s = match.start('include')
                e = match.end('include')
                file.lines[i] = line[:s] + rewritten + line[e:]
            file.write()

    def _rewrite_includes(self):
        self._log("Rewriting includes")
        for original, rewritten in self._rewritten_includes:
            self._rewrite_include(original, rewritten)

    def _replace_xxh64_prefix(self):
        """
        Rename XXH64 symbols for an external xxhash implementation.

        `XXH64_state_t` becomes `xxh64_state` (when given) and the `XXH64`
        in every `XXH64_` identifier prefix becomes `xxh64_prefix` (when
        given). The two options apply independently; the state rename runs
        first.
        """
        replacements = []
        if self._xxh64_state is not None:
            self._log(f"Replacing XXH64_state_t with {self._xxh64_state}")
            replacements.append(
                (re.compile(r"(?<!\w)XXH64_state_t(?!\w)"), self._xxh64_state)
            )
        if self._xxh64_prefix is not None:
            self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
            replacements.append(
                (re.compile(r"(?<!\w)XXH64(?=_)"), self._xxh64_prefix)
            )
        if not replacements:
            return
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            for i, original in enumerate(file.lines):
                line = original
                for regex, replacement in replacements:
                    # A callable replacement is inserted literally and each
                    # match is replaced once, so a replacement that contains
                    # the pattern cannot loop forever.
                    line = regex.sub(lambda _match, text=replacement: text, line)
                if line != original:
                    self._log(f"\t- {original[:-1]}")
                    self._log(f"\t+ {line[:-1]}")
                    file.lines[i] = line
            file.write()

    def go(self):
        """Copy the library, then apply every rewrite step in order."""
        self._copy_source_lib()
        self._copy_zstd_deps()
        self._copy_mem()
        self._hardwire_defines()
        self._remove_excludes()
        self._rewrite_includes()
        self._replace_xxh64_prefix()


def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
    """Parse 'NAME' or 'NAME=VALUE' strings into (NAME, VALUE-or-None) pairs."""
    output = []
    for define in defines:
        parsed = define.split('=')
        if len(parsed) == 1:
            output.append((parsed[0], None))
        elif len(parsed) == 2:
            output.append((parsed[0], parsed[1]))
        else:
            raise RuntimeError(f"Bad define: {define}")
    return output


def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
    """Parse 'KEY=VALUE' strings into (KEY, VALUE) pairs; '=' is required."""
    output = []
    for rewritten_include in rewritten_includes:
        parsed = rewritten_include.split('=')
        if len(parsed) == 2:
            output.append((parsed[0], parsed[1]))
        else:
            raise RuntimeError(f"Bad rewritten include: {rewritten_include}")
    return output


def main(name, args):
    """
    Command-line entry point: `name` is the program name, `args` the argv tail.

    Raises RuntimeError for contradictory macro options or for xxh64 options
    given without --xxhash.
    """
    parser = argparse.ArgumentParser(prog=name)
    parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file")
    parser.add_argument("--mem", default="mem.h", help="Memory module")
    parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
    parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
    parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
    parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
    parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
    parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
    parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
    parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed mutliple times)")
    parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
    parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
    args = parser.parse_args(args)

    # Always remove threading
    if "ZSTD_MULTITHREAD" not in args.undefs:
        args.undefs.append("ZSTD_MULTITHREAD")

    args.defs = parse_optional_pair(args.defs)
    # args.defs holds (name, value) pairs; compare names, not tuples.
    defined_names = {macro for macro, _ in args.defs}
    for macro, _ in args.defs:
        if macro in args.undefs:
            raise RuntimeError(f"{macro} is both defined and undefined!")

    args.replaces = parse_pair(args.replaces)
    for macro, _ in args.replaces:
        if macro in args.undefs or macro in defined_names:
            raise RuntimeError(f"{macro} is both replaced and (un)defined!")

    args.rewritten_includes = parse_pair(args.rewritten_includes)

    external_xxhash = False
    if args.xxhash is not None:
        external_xxhash = True
        args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))

    if args.xxh64_prefix is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")

    if args.xxh64_state is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-state may only be used with --xxhash provided")

    Freestanding(
        args.zstd_deps,
        args.mem,
        args.source_lib,
        args.output_lib,
        external_xxhash,
        args.xxh64_state,
        args.xxh64_prefix,
        args.rewritten_includes,
        args.defs,
        args.replaces,
        args.undefs,
        args.excludes
    ).go()


if __name__ == "__main__":
    main(sys.argv[0], sys.argv[1:])