1import contextlib
2from abc import abstractmethod
3
4from typing import AbstractSet, Dict, IO, Iterator, List, Optional, Set, Text, Tuple
5
6from pegen import sccutils
7from pegen.grammar import (
8    Grammar,
9    Rule,
10    Rhs,
11    Alt,
12    NamedItem,
13    Plain,
14    NameLeaf,
15    Gather,
16)
17from pegen.grammar import GrammarError, GrammarVisitor
18
19
class RuleCheckingVisitor(GrammarVisitor):
    """Grammar visitor that validates names used inside rules.

    Raises GrammarError when:
    * a NameLeaf refers to something that is neither a key of ``rules`` nor
      one of the token names in ``tokens.values()``;
    * a NamedItem binds a variable whose name starts with an underscore.
    """

    def __init__(self, rules: Dict[str, Rule], tokens: Dict[int, str]):
        self.rules = rules
        self.tokens = tokens

    def visit_NameLeaf(self, node: NameLeaf) -> None:
        name = node.value
        # A leaf is valid if it names a rule or a token.
        if name in self.rules or name in self.tokens.values():
            return
        # TODO: Add line/col info to (leaf) nodes
        raise GrammarError(f"Dangling reference to rule {name!r}")

    def visit_NamedItem(self, node: NamedItem) -> None:
        # node.name may be None (unnamed item); treat that like an empty name.
        if (node.name or "").startswith("_"):
            raise GrammarError(f"Variable names cannot start with underscore: '{node.name}'")
        self.visit(node.item)
34
35
class ParserGenerator:
    """Common machinery shared by pegen's parser generators.

    The constructor validates the grammar and precomputes analysis data:

    * rule names must not start with an underscore (``validate_rule_names``);
    * a grammar without a ``trailer`` meta must define a ``start`` rule;
    * every rule is checked by ``RuleCheckingVisitor``;
    * nullability (``compute_nullables``) and left recursion
      (``compute_left_recursives``) are computed for all rules.

    Concrete generators implement ``generate`` and emit text through
    ``print``/``printblock``, which prefix output with the current indent.
    """

    # Declared but never assigned in this class; presumably set by subclasses.
    callmakervisitor: GrammarVisitor

    def __init__(self, grammar: Grammar, tokens: Dict[int, str], file: Optional[IO[Text]]):
        """Validate ``grammar`` and prepare generator state.

        Args:
            grammar: Parsed grammar; its ``rules`` and ``metas`` are read here.
            tokens: Mapping of token number to token name; the names are
                accepted as valid references by the rule checker.
            file: Output stream for ``print``; ``None`` means the builtin
                ``print`` default (``sys.stdout``).

        Raises:
            GrammarError: on invalid rule/variable names, dangling references,
                or a missing ``start`` rule when there is no ``trailer`` meta.
        """
        self.grammar = grammar
        self.tokens = tokens
        self.rules = grammar.rules
        self.validate_rule_names()
        if "trailer" not in grammar.metas and "start" not in self.rules:
            raise GrammarError("Grammar without a trailer must have a 'start' rule")
        checker = RuleCheckingVisitor(self.rules, self.tokens)
        for rule in self.rules.values():
            checker.visit(rule)
        self.file = file
        self.level = 0  # Current indentation level used by print().
        compute_nullables(self.rules)
        self.first_graph, self.first_sccs = compute_left_recursives(self.rules)
        self.todo = self.rules.copy()  # Rules still to generate (grows with helper rules).
        self.counter = 0  # Suffix counter for name_node()/name_loop()/name_gather().
        self.keyword_counter = 499  # For keyword_type(); first value handed out is 500.
        self.all_rules: Dict[str, Rule] = {}  # User rules + synthesized helper rules.
        self._local_variable_stack: List[List[str]] = []  # One name list per scope.

    def validate_rule_names(self) -> None:
        """Raise GrammarError if any grammar rule name starts with an underscore.

        Underscore-prefixed names are the ones this class synthesizes for
        helper rules (``_tmp_N``, ``_loop0_N``, ``_loop1_N``, ``_gather_N``).
        """
        for rule in self.rules:
            if rule.startswith("_"):
                raise GrammarError(f"Rule names cannot start with underscore: '{rule}'")

    @contextlib.contextmanager
    def local_variable_context(self) -> Iterator[None]:
        """Open a fresh local-variable scope for ``dedupe`` while the block runs.

        The scope is popped in ``finally`` so an exception raised inside the
        ``with`` body cannot leave a stale scope on the stack (mirrors
        ``indent``).
        """
        self._local_variable_stack.append([])
        try:
            yield
        finally:
            self._local_variable_stack.pop()

    @property
    def local_variable_names(self) -> List[str]:
        """Names already used in the innermost local-variable scope.

        Raises IndexError when accessed outside ``local_variable_context``.
        """
        return self._local_variable_stack[-1]

    @abstractmethod
    def generate(self, filename: str) -> None:
        """Emit the generated parser; implemented by concrete generators.

        Note: this class does not derive from ``abc.ABC``, so the
        ``@abstractmethod`` marker is not enforced at instantiation; calling
        this base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    @contextlib.contextmanager
    def indent(self) -> Iterator[None]:
        """Increase the output indent level by one for the duration of the block."""
        self.level += 1
        try:
            yield
        finally:
            self.level -= 1

    def print(self, *args: object) -> None:
        """Write ``args`` to ``self.file`` indented by four spaces per level.

        With no arguments a bare newline is written (no indentation).
        """
        if not args:
            print(file=self.file)
        else:
            print("    " * self.level, end="", file=self.file)
            print(*args, file=self.file)

    def printblock(self, lines: str) -> None:
        """Print each line of a multi-line string at the current indent level."""
        for line in lines.splitlines():
            self.print(line)

    def collect_todo(self) -> None:
        """Process ``self.todo`` until no new rules are added.

        Each not-yet-processed rule gets ``rule.collect_todo(self)`` called;
        presumably that may register helper rules via ``name_node``,
        ``name_loop`` or ``name_gather``, which extend ``self.todo``. Every
        rule seen is also recorded in ``self.all_rules``.
        """
        done: Set[str] = set()
        while True:
            alltodo = list(self.todo)
            self.all_rules.update(self.todo)
            todo = [i for i in alltodo if i not in done]
            if not todo:
                break
            for rulename in todo:
                self.todo[rulename].collect_todo(self)
            done = set(alltodo)

    def keyword_type(self) -> int:
        """Return the next keyword token type (500, 501, ...)."""
        self.keyword_counter += 1
        return self.keyword_counter

    def name_node(self, rhs: Rhs) -> str:
        """Register a helper rule ``_tmp_N`` with body ``rhs`` and return its name."""
        self.counter += 1
        name = f"_tmp_{self.counter}"  # TODO: Pick a nicer name.
        self.todo[name] = Rule(name, None, rhs)
        return name

    def name_loop(self, node: Plain, is_repeat1: bool) -> str:
        """Register a helper loop rule over ``node`` and return its name.

        The rule is named ``_loop1_N`` when ``is_repeat1`` is true and
        ``_loop0_N`` otherwise; the name prefix is the only place the loop
        kind is recorded.
        """
        self.counter += 1
        if is_repeat1:
            prefix = "_loop1_"
        else:
            prefix = "_loop0_"
        name = f"{prefix}{self.counter}"  # TODO: It's ugly to signal via the name.
        self.todo[name] = Rule(name, None, Rhs([Alt([NamedItem(None, node)])]))
        return name

    def name_gather(self, node: Gather) -> str:
        """Register the two helper rules for a gather and return the entry name.

        Consumes two counter values, N and N+1, producing::

            _gather_N:     elem=<node.node> seq=_loop0_{N+1}
            _loop0_{N+1}:  <node.separator> elem=<node.node> { elem }
        """
        self.counter += 1
        name = f"_gather_{self.counter}"
        self.counter += 1
        extra_function_name = f"_loop0_{self.counter}"
        extra_function_alt = Alt(
            [NamedItem(None, node.separator), NamedItem("elem", node.node)], action="elem"
        )
        self.todo[extra_function_name] = Rule(extra_function_name, None, Rhs([extra_function_alt]))
        alt = Alt([NamedItem("elem", node.node), NamedItem("seq", NameLeaf(extra_function_name))])
        self.todo[name] = Rule(name, None, Rhs([alt]))
        return name

    def dedupe(self, name: str) -> str:
        """Return a name unique within the current local-variable scope.

        Tries ``name``, then ``name_1``, ``name_2``, ...; the chosen name is
        recorded in the current scope before being returned.
        """
        origname = name
        counter = 0
        while name in self.local_variable_names:
            counter += 1
            name = f"{origname}_{counter}"
        self.local_variable_names.append(name)
        return name
153
154
def compute_nullables(rules: Dict[str, Rule]) -> None:
    """Compute nullability for every rule of a grammar, in place.

    Calls ``nullable_visit`` on each rule, handing it the complete rule
    mapping. Nothing is returned; results live on the rule objects.

    Thanks to TatSu (tatsu/leftrec.py) for inspiration.
    """
    for rulename in rules:
        rules[rulename].nullable_visit(rules)
162
163
def compute_left_recursives(
    rules: Dict[str, Rule]
) -> Tuple[Dict[str, AbstractSet[str]], List[AbstractSet[str]]]:
    """Mark left-recursive rules and choose a leader for each recursive cycle.

    Builds the left-invocation graph (``make_first_graph``), splits it into
    strongly connected components, and:

    * for a component with several members, sets ``left_recursive = True`` on
      each member and ``leader = True`` on one member that lies on every cycle
      reported by ``sccutils.find_cycles_in_scc``;
    * for a single-member component with a self-edge, sets both flags on
      that rule.

    Nullable flags must already be computed (see ``make_first_graph``).

    Returns:
        The left-invocation graph and the list of its strongly connected
        components.

    Raises:
        ValueError: if a multi-member component has no member common to all
            of its cycles, so no single leader can be chosen.
    """
    graph = make_first_graph(rules)
    sccs = list(sccutils.strongly_connected_components(graph.keys(), graph))
    for scc in sccs:
        if len(scc) > 1:
            for name in scc:
                rules[name].left_recursive = True
            # Try to find a leader such that all cycles go through it.
            leaders = set(scc)
            for start in scc:
                for cycle in sccutils.find_cycles_in_scc(graph, scc, start):
                    # Any member missing from this cycle cannot be the leader.
                    leaders -= scc - set(cycle)
                    if not leaders:
                        raise ValueError(
                            f"SCC {scc} has no leadership candidate (no element is included in all cycles)"
                        )
            # Any candidate works; min() just makes the choice deterministic.
            leader = min(leaders)
            rules[leader].leader = True
        else:
            name = min(scc)  # The only element.
            if name in graph[name]:
                rules[name].left_recursive = True
                rules[name].leader = True
    return graph, sccs
192
193
def make_first_graph(rules: Dict[str, Rule]) -> Dict[str, AbstractSet[str]]:
    """Compute the graph of left-invocations.

    There's an edge from A to B if A may invoke B at its initial
    position (as reported by ``Rule.initial_names()``).

    Every name that appears as a successor also gets a vertex of its own; names
    that are not keys of ``rules`` (e.g. tokens) get an empty successor set,
    so the graph is closed under lookup.

    Note that this requires the nullable flags to have been computed.
    """
    graph: Dict[str, AbstractSet[str]] = {}
    vertices: Set[str] = set()
    for rulename, rule in rules.items():
        graph[rulename] = names = rule.initial_names()
        vertices |= names
    for vertex in vertices:
        # setdefault keeps the edges already recorded for real rules.
        graph.setdefault(vertex, set())
    return graph
210