1import token
2from typing import Any, Dict, Optional, IO, Text, Tuple
3
4from pegen.grammar import (
5    Cut,
6    GrammarVisitor,
7    NameLeaf,
8    StringLeaf,
9    Rhs,
10    NamedItem,
11    Lookahead,
12    PositiveLookahead,
13    NegativeLookahead,
14    Opt,
15    Repeat0,
16    Repeat1,
17    Gather,
18    Group,
19    Rule,
20    Alt,
21)
22from pegen import grammar
23from pegen.parser_generator import ParserGenerator
24
# Default text printed at the top of the generated parser module.
# PythonParserGenerator.generate() falls back to it when the grammar's metas
# have no "header" entry, strips trailing newlines and passes it through
# str.format(filename=...), which fills in the {filename} placeholder.
MODULE_PREFIX = """\
#!/usr/bin/env python3.8
# @generated by pegen from {filename}

import ast
import sys
import tokenize

from typing import Any, Optional

from pegen.parser import memoize, memoize_left_rec, logger, Parser

"""
# Default text printed at the end of the generated parser module.
# PythonParserGenerator.generate() falls back to it when the grammar's metas
# have no "trailer" entry; unlike the prefix it is printed without formatting.
MODULE_SUFFIX = """

if __name__ == '__main__':
    from pegen.parser import simple_parser_main
    simple_parser_main(GeneratedParser)
"""
44
45
class PythonCallMakerVisitor(GrammarVisitor):
    """Map grammar nodes to ``(name, call)`` pairs used in generated parser code.

    ``call`` is the source text of the expression that parses the node and
    ``name`` is the suggested variable name for its result (``None`` for
    lookaheads, which bind nothing).
    """

    def __init__(self, parser_generator: ParserGenerator):
        # Results for nodes that get a generator-assigned name are memoized here
        # so that visiting the same node twice yields the same pair.
        self.cache: Dict[Any, Any] = {}
        self.gen = parser_generator

    def visit_NameLeaf(self, node: NameLeaf) -> Tuple[Optional[str], str]:
        """Return the pair for a bare name: a token helper, ``expect`` or a rule call."""
        leaf = node.value
        if leaf in ("NAME", "NUMBER", "STRING", "OP"):
            # These tokens are parsed through a lower-cased method of the same name.
            method = leaf.lower()
            return method, f"self.{method}()"
        if leaf in ("NEWLINE", "DEDENT", "INDENT", "ENDMARKER", "ASYNC", "AWAIT"):
            # These tokens are matched by name through self.expect().
            return leaf.lower(), f"self.expect({leaf!r})"
        # Anything else is emitted as a call to the method named after it.
        return leaf, f"self.{leaf}()"

    def visit_StringLeaf(self, node: StringLeaf) -> Tuple[str, str]:
        """Return the pair for a string literal; its value is inserted verbatim."""
        return "literal", f"self.expect({node.value})"

    def visit_Rhs(self, node: Rhs) -> Tuple[Optional[str], str]:
        """Return the (memoized) pair for a right-hand side.

        A right-hand side made of one alternative with one item is delegated to
        that item; otherwise the generator is asked to name the node.
        """
        try:
            return self.cache[node]
        except KeyError:
            pass
        alternatives = node.alts
        if len(alternatives) == 1 and len(alternatives[0].items) == 1:
            pair = self.visit(alternatives[0].items[0])
        else:
            rule_name = self.gen.name_node(node)
            pair = rule_name, f"self.{rule_name}()"
        self.cache[node] = pair
        return pair

    def visit_NamedItem(self, node: NamedItem) -> Tuple[Optional[str], str]:
        """Return the pair of the wrapped item, preferring the item's explicit name."""
        inferred, call = self.visit(node.item)
        return (node.name if node.name else inferred), call

    def lookahead_call_helper(self, node: Lookahead) -> Tuple[str, str]:
        """Split the call of the lookahead's operand into callee and argument text.

        For example ``self.expect('x')`` becomes ``("self.expect", "'x'")``.
        """
        _, call = self.visit(node.node)
        callee, arguments = call.split("(", 1)
        assert arguments[-1] == ")"
        return callee, arguments[:-1]

    def visit_PositiveLookahead(self, node: PositiveLookahead) -> Tuple[None, str]:
        """Return a nameless ``self.positive_lookahead(...)`` call."""
        callee, arguments = self.lookahead_call_helper(node)
        return None, f"self.positive_lookahead({callee}, {arguments})"

    def visit_NegativeLookahead(self, node: NegativeLookahead) -> Tuple[None, str]:
        """Return a nameless ``self.negative_lookahead(...)`` call."""
        callee, arguments = self.lookahead_call_helper(node)
        return None, f"self.negative_lookahead({callee}, {arguments})"

    def visit_Opt(self, node: Opt) -> Tuple[str, str]:
        """Return the operand's call with exactly one trailing comma, named ``opt``."""
        _, call = self.visit(node.node)
        # Note trailing comma (the call may already have one comma
        # at the end, for example when rules have both repeat0 and optional
        # markers, e.g: [rule*])
        return "opt", (call if call.endswith(",") else f"{call},")

    def visit_Repeat0(self, node: Repeat0) -> Tuple[str, str]:
        """Return the (memoized) call to the loop named by the generator for ``node*``."""
        pair = self.cache.get(node)
        if pair is None:
            loop_name = self.gen.name_loop(node.node, False)
            pair = loop_name, f"self.{loop_name}(),"  # Also a trailing comma!
            self.cache[node] = pair
        return pair

    def visit_Repeat1(self, node: Repeat1) -> Tuple[str, str]:
        """Return the (memoized) call to the loop named by the generator for ``node+``."""
        pair = self.cache.get(node)
        if pair is None:
            loop_name = self.gen.name_loop(node.node, True)
            pair = loop_name, f"self.{loop_name}()"  # But no trailing comma here!
            self.cache[node] = pair
        return pair

    def visit_Gather(self, node: Gather) -> Tuple[str, str]:
        """Return the (memoized) call to the rule named by the generator for a gather."""
        pair = self.cache.get(node)
        if pair is None:
            gather_name = self.gen.name_gather(node)
            pair = gather_name, f"self.{gather_name}()"  # No trailing comma here either!
            self.cache[node] = pair
        return pair

    def visit_Group(self, node: Group) -> Tuple[Optional[str], str]:
        """Return the pair of the group's right-hand side."""
        return self.visit(node.rhs)

    def visit_Cut(self, node: Cut) -> Tuple[str, str]:
        """Return the pair for a cut: the name ``cut`` with the expression ``True``."""
        return "cut", "True"
130
131
class PythonParserGenerator(ParserGenerator, GrammarVisitor):
    """Print the Python source of a ``GeneratedParser`` class for a grammar."""

    def __init__(
        self,
        grammar: grammar.Grammar,
        file: Optional[IO[Text]],
        tokens: Dict[int, str] = token.tok_name,
    ):
        super().__init__(grammar, tokens, file)
        # Helper visitor that turns grammar items into call expressions.
        self.callmakervisitor = PythonCallMakerVisitor(self)

    def generate(self, filename: str) -> None:
        """Print the whole parser module: header, subheader, parser class and trailer.

        The "header", "subheader" and "trailer" entries of the grammar's metas
        override the defaults; a header or trailer that is None is skipped.
        """
        header = self.grammar.metas.get("header", MODULE_PREFIX)
        if header is not None:
            self.print(header.rstrip("\n").format(filename=filename))

        subheader = self.grammar.metas.get("subheader", "")
        if subheader:
            self.print(subheader.format(filename=filename))

        self.print("class GeneratedParser(Parser):")
        # The outer loop repeats until self.todo is empty; each pass works on a
        # snapshot so entries can be deleted while iterating.
        while self.todo:
            pending = list(self.todo.items())
            for rulename, rule in pending:
                del self.todo[rulename]
                self.print()
                with self.indent():
                    self.visit(rule)

        trailer = self.grammar.metas.get("trailer", MODULE_SUFFIX)
        if trailer is not None:
            self.print(trailer.rstrip("\n"))

    def visit_Rule(self, node: Rule) -> None:
        """Print the decorated method that implements one grammar rule."""
        is_loop = node.is_loop()
        is_gather = node.is_gather()
        rhs = node.flatten()

        if not node.left_recursive:
            decorator = "@memoize"
        elif node.leader:
            decorator = "@memoize_left_rec"
        else:
            # Non-leader rules in a cycle are not memoized,
            # but they must still be logged.
            decorator = "@logger"
        self.print(decorator)

        return_type = node.type or "Any"
        self.print(f"def {node.name}(self) -> Optional[{return_type}]:")
        with self.indent():
            self.print(f"# {node.name}: {rhs}")
            if node.nullable:
                self.print(f"# nullable={node.nullable}")
            self.print("mark = self.mark()")
            if is_loop:
                self.print("children = []")
            self.visit(rhs, is_loop=is_loop, is_gather=is_gather)
            # Fall-through result of the generated method.
            self.print("return children" if is_loop else "return None")

    def visit_NamedItem(self, node: NamedItem) -> None:
        """Print one item of a condition, binding its result when it has a name."""
        inferred, call = self.callmakervisitor.visit(node.item)
        name = node.name if node.name else inferred
        if not name:
            # Nameless items are printed as a bare expression.
            self.print(call)
            return
        if name != "cut":
            name = self.dedupe(name)
        self.print(f"({name} := {call})")

    def visit_Rhs(self, node: Rhs, is_loop: bool = False, is_gather: bool = False) -> None:
        """Print the code for every alternative, in order."""
        alternatives = node.alts
        if is_loop:
            assert len(alternatives) == 1
        for alternative in alternatives:
            self.visit(alternative, is_loop=is_loop, is_gather=is_gather)

    def visit_Alt(self, node: Alt, is_loop: bool, is_gather: bool) -> None:
        """Print the ``if``/``while`` statement that tries one alternative.

        Items are joined with ``and``. On success the printed code returns the
        action (or appends it to ``children`` for loops); afterwards it resets
        to ``mark`` and returns None if a cut was reached.
        """
        with self.local_variable_context():
            self.print("cut = False")  # TODO: Only if needed.
            self.print("while (" if is_loop else "if (")
            with self.indent():
                for position, item in enumerate(node.items):
                    if position:
                        self.print("and")
                    self.visit(item)
                    if is_gather:
                        self.print("is not None")

            self.print("):")
            with self.indent():
                action = node.action
                if not action:
                    # No explicit action: build a default from the bound variables.
                    if is_gather:
                        assert len(self.local_variable_names) == 2
                        first, rest = self.local_variable_names[0], self.local_variable_names[1]
                        action = f"[{first}] + {rest}"
                    else:
                        action = f"[{', '.join(self.local_variable_names)}]"
                if is_loop:
                    self.print(f"children.append({action})")
                    self.print("mark = self.mark()")
                else:
                    self.print(f"return {action}")
            self.print("self.reset(mark)")
            # Skip remaining alternatives if a cut was reached.
            self.print("if cut: return None")  # TODO: Only if needed.
242