1from __future__ import annotations
2
3from abc import abstractmethod
4from typing import (
5    AbstractSet,
6    Any,
7    Dict,
8    Iterable,
9    Iterator,
10    List,
11    Optional,
12    Set,
13    Tuple,
14    TYPE_CHECKING,
15    Union,
16)
17
18
19if TYPE_CHECKING:
20    from pegen.parser_generator import ParserGenerator
21
22
class GrammarError(Exception):
    """Dedicated exception type for grammar errors."""
25
26
class GrammarVisitor:
    """Dispatch on node class name to ``visit_<ClassName>`` methods.

    Nodes without a dedicated method are handled by ``generic_visit``,
    which walks whatever the node yields when iterated.
    """

    def visit(self, node: Any, *args: Any, **kwargs: Any) -> Any:
        """Visit a node and return whatever its handler returns."""
        handler_name = f"visit_{node.__class__.__name__}"
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node, *args, **kwargs)

    def generic_visit(self, node: Iterable[Any], *args: Any, **kwargs: Any) -> None:
        """Fallback used when no explicit visitor method exists for a node.

        Every value produced by iterating ``node`` is visited; a value that
        is a list is expanded one level and its members are visited instead.
        """
        for child in node:
            members = child if isinstance(child, list) else [child]
            for member in members:
                self.visit(member, *args, **kwargs)
42
43
class Grammar:
    """Container for the rules and meta declarations of a grammar.

    Attributes:
        rules: dict mapping each rule's ``name`` to the rule, in the order
            the rules were supplied (a later rule with the same name
            replaces an earlier one).
        metas: dict built from the ``(name, value)`` meta pairs.
    """

    def __init__(self, rules: Iterable[Rule], metas: Iterable[Tuple[str, Optional[str]]]):
        self.rules = {rule.name: rule for rule in rules}
        self.metas = dict(metas)

    def __str__(self) -> str:
        """Return the ``str()`` of every rule, one rule per line."""
        return "\n".join(str(rule) for rule in self.rules.values())

    def __repr__(self) -> str:
        """Return a multi-line ``Grammar([...rules...], [...metas...])`` form."""
        lines = ["Grammar(", "  ["]
        lines.extend(f"    {rule!r}," for rule in self.rules.values())
        lines.append("  ],")
        # This must be an f-string: a plain literal would print the
        # placeholder text itself instead of the list of meta pairs.
        lines.append(f"  {list(self.metas.items())!r}")
        lines.append(")")
        return "\n".join(lines)

    def __iter__(self) -> Iterator[Rule]:
        """Iterate over the rules in insertion order."""
        yield from self.rules.values()
64
65
# Global flag read by the __str__() methods of Rule, Alt and NamedItem.
# While it is true they print the simple form: rule types, item names and
# actions are left out.  Default: true, i.e. actions off.
SIMPLE_STR = True
68
69
class Rule:
    """A named grammar rule: a name, an optional type and a right-hand side.

    Besides the constructor arguments each rule carries the boolean flags
    ``visited``, ``nullable``, ``left_recursive`` and ``leader``, which all
    start out false.  Within this class only ``visited`` and ``nullable``
    are ever updated (by ``nullable_visit``).
    """

    def __init__(self, name: str, type: Optional[str], rhs: Rhs, memo: Optional[object] = None):
        self.name = name
        self.type = type
        self.rhs = rhs
        # Only the truthiness of ``memo`` is retained.
        self.memo = bool(memo)
        # Analysis flags, all initially false.
        self.visited = self.nullable = self.left_recursive = self.leader = False

    def is_loop(self) -> bool:
        """Tell whether the rule name starts with ``_loop``."""
        return self.name.startswith("_loop")

    def is_gather(self) -> bool:
        """Tell whether the rule name starts with ``_gather``."""
        return self.name.startswith("_gather")

    def __str__(self) -> str:
        """Render the rule on one line, or one alternative per line if long.

        The ``[type]`` suffix is shown only when ``SIMPLE_STR`` is false and
        the rule has a type.
        """
        show_type = not SIMPLE_STR and self.type is not None
        if show_type:
            one_line = f"{self.name}[{self.type}]: {self.rhs}"
        else:
            one_line = f"{self.name}: {self.rhs}"
        if len(one_line) < 88:
            return one_line
        # Too long: keep everything before the first colon as the header and
        # list each alternative on a line of its own.
        header = one_line.partition(":")[0] + ":"
        return "\n".join([header] + [f"    | {alt}" for alt in self.rhs.alts])

    def __repr__(self) -> str:
        return "Rule({!r}, {!r}, {!r})".format(self.name, self.type, self.rhs)

    def __iter__(self) -> Iterator[Rhs]:
        """Yield the right-hand side as the only child."""
        yield self.rhs

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Compute, record and return whether this rule is nullable.

        The first call marks the rule visited and stores the right-hand
        side's answer in ``self.nullable``; any later call returns False
        without recomputing.
        """
        if not self.visited:
            self.visited = True
            self.nullable = self.rhs.nullable_visit(rules)
            return self.nullable
        # A left-recursive rule is considered non-nullable.
        return False

    def initial_names(self) -> AbstractSet[str]:
        """Return the initial names of the right-hand side."""
        return self.rhs.initial_names()

    def flatten(self) -> Rhs:
        """Return the effective right-hand side of the rule.

        If the rule is not a loop rule and its right-hand side is exactly one
        alternative holding exactly one item that wraps a ``Group``, the
        group's own right-hand side is returned; otherwise ``self.rhs``.
        """
        rhs = self.rhs
        if self.is_loop():
            return rhs
        if len(rhs.alts) != 1:
            return rhs
        only_alt = rhs.alts[0]
        if len(only_alt.items) != 1:
            return rhs
        wrapped = only_alt.items[0].item
        # A single parenthesized group: flatten it.
        return wrapped.rhs if isinstance(wrapped, Group) else rhs

    def collect_todo(self, gen: ParserGenerator) -> None:
        """Run ``collect_todo`` on the flattened right-hand side."""
        self.flatten().collect_todo(gen)
130
131
class Leaf:
    """Base class for nodes that hold a single string value and no children."""

    def __init__(self, value: str):
        self.value = value

    def __str__(self) -> str:
        """Return the stored value unchanged."""
        return self.value

    def __iter__(self) -> Iterable[str]:
        """Yield nothing: a leaf has no children to walk."""
        yield from ()

    @abstractmethod
    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Subclasses must override; this base version always raises."""
        raise NotImplementedError

    @abstractmethod
    def initial_names(self) -> AbstractSet[str]:
        """Subclasses must override; this base version always raises."""
        raise NotImplementedError
150
151
class NameLeaf(Leaf):
    """The value is the name."""

    def __str__(self) -> str:
        """Render ``ENDMARKER`` as ``$``; any other name as itself."""
        return "$" if self.value == "ENDMARKER" else super().__str__()

    def __repr__(self) -> str:
        return "NameLeaf({!r})".format(self.value)

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Delegate to the named rule when there is one, else return False."""
        if self.value not in rules:
            # Token or unknown; never empty.
            return False
        return rules[self.value].nullable_visit(rules)

    def initial_names(self) -> AbstractSet[str]:
        """Return a one-element set holding this leaf's name."""
        return {self.value}
171
172
class StringLeaf(Leaf):
    """The value is a string literal, including quotes."""

    def __repr__(self) -> str:
        return "StringLeaf({!r})".format(self.value)

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Return True only when the stored value is falsy (empty)."""
        # The string token '' is considered empty.
        # NOTE(review): the class docstring says the value includes its
        # quotes; if so, a '' literal is stored as a two-character string
        # and would not be falsy here -- confirm against the caller.
        is_empty = not self.value
        return is_empty

    def initial_names(self) -> AbstractSet[str]:
        """A string literal contributes no names: return an empty set."""
        return set()
185
186
class Rhs:
    """A right-hand side: an ordered list of alternatives.

    ``memo`` starts as None and is not modified by this class.
    """

    def __init__(self, alts: List[Alt]):
        self.alts = alts
        self.memo: Optional[Tuple[Optional[str], str]] = None

    def __str__(self) -> str:
        """Join the alternatives with ``" | "``."""
        return " | ".join(map(str, self.alts))

    def __repr__(self) -> str:
        return "Rhs({!r})".format(self.alts)

    def __iter__(self) -> Iterator[List[Alt]]:
        """Yield the list of alternatives as a single child."""
        yield self.alts

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Return True as soon as one alternative reports itself nullable.

        Alternatives after the first nullable one are not visited.
        """
        return any(alt.nullable_visit(rules) for alt in self.alts)

    def initial_names(self) -> AbstractSet[str]:
        """Return the union of the initial names of all alternatives."""
        collected: Set[str] = set()
        for alternative in self.alts:
            collected |= alternative.initial_names()
        return collected

    def collect_todo(self, gen: ParserGenerator) -> None:
        """Run ``collect_todo`` on every alternative, in order."""
        for alternative in self.alts:
            alternative.collect_todo(gen)
216
217
class Alt:
    """One alternative: a sequence of named items plus optional extras.

    ``icut`` defaults to -1 and ``action`` to None; both are keyword-only.
    """

    def __init__(self, items: List[NamedItem], *, icut: int = -1, action: Optional[str] = None):
        self.items = items
        self.icut = icut
        self.action = action

    def __str__(self) -> str:
        """Join the items with spaces.

        The action is appended in braces only when ``SIMPLE_STR`` is false
        and the action is truthy.
        """
        text = " ".join(map(str, self.items))
        if SIMPLE_STR or not self.action:
            return text
        return f"{text} {{ {self.action} }}"

    def __repr__(self) -> str:
        """Show the items, plus ``icut`` when >= 0 and ``action`` when truthy."""
        shown = [repr(self.items)]
        if self.icut >= 0:
            shown.append(f"icut={self.icut}")
        if self.action:
            shown.append(f"action={self.action!r}")
        return "Alt(" + ", ".join(shown) + ")"

    def __iter__(self) -> Iterator[List[NamedItem]]:
        """Yield the list of items as a single child."""
        yield self.items

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Return True when every item is nullable.

        Visiting stops at the first item that is not nullable, so later
        items are not visited in that case.
        """
        return all(item.nullable_visit(rules) for item in self.items)

    def initial_names(self) -> AbstractSet[str]:
        """Collect the initial names of the leading items.

        Items are taken in order up to and including the first one whose
        ``nullable`` attribute is false.
        """
        collected: Set[str] = set()
        for entry in self.items:
            collected |= entry.initial_names()
            if not entry.nullable:
                break
        return collected

    def collect_todo(self, gen: ParserGenerator) -> None:
        """Run ``collect_todo`` on every item, in order."""
        for entry in self.items:
            entry.collect_todo(gen)
259
260
class NamedItem:
    """An item inside an alternative, optionally bound to a name.

    ``nullable`` starts as False and is refreshed by ``nullable_visit``.
    """

    def __init__(self, name: Optional[str], item: Item):
        self.name = name
        self.item = item
        self.nullable = False

    def __str__(self) -> str:
        """Render as ``name=item`` or just ``item``.

        The name is included only when ``SIMPLE_STR`` is false and the name
        is truthy.
        """
        if SIMPLE_STR or not self.name:
            return str(self.item)
        return f"{self.name}={self.item}"

    def __repr__(self) -> str:
        return "NamedItem({!r}, {!r})".format(self.name, self.item)

    def __iter__(self) -> Iterator[Item]:
        """Yield the wrapped item as the only child."""
        yield self.item

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Store the wrapped item's nullability on ``self`` and return it."""
        result = self.item.nullable_visit(rules)
        self.nullable = result
        return self.nullable

    def initial_names(self) -> AbstractSet[str]:
        """Return the initial names of the wrapped item."""
        return self.item.initial_names()

    def collect_todo(self, gen: ParserGenerator) -> None:
        """Pass the wrapped item to ``gen.callmakervisitor.visit``."""
        visitor = gen.callmakervisitor
        visitor.visit(self.item)
288
289
class Lookahead:
    """Base class for lookahead nodes: a wrapped node plus a sign string."""

    def __init__(self, node: Plain, sign: str):
        self.node = node
        self.sign = sign

    def __str__(self) -> str:
        """Return the sign immediately followed by the node."""
        return "{}{}".format(self.sign, self.node)

    def __iter__(self) -> Iterator[Plain]:
        """Yield the wrapped node as the only child."""
        yield self.node

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always True; the wrapped node is not visited."""
        return True

    def initial_names(self) -> AbstractSet[str]:
        """Always an empty set."""
        return set()
306
307
class PositiveLookahead(Lookahead):
    """Lookahead whose sign is ``&``."""

    def __init__(self, node: Plain):
        super().__init__(node, sign="&")

    def __repr__(self) -> str:
        return "PositiveLookahead({!r})".format(self.node)
314
315
class NegativeLookahead(Lookahead):
    """Lookahead whose sign is ``!``."""

    def __init__(self, node: Plain):
        super().__init__(node, sign="!")

    def __repr__(self) -> str:
        return "NegativeLookahead({!r})".format(self.node)
322
323
class Opt:
    """An optional item wrapping a single node."""

    def __init__(self, node: Item):
        self.node = node

    def __str__(self) -> str:
        """Render as ``[X]`` when X contains a space, else as ``X?``."""
        text = str(self.node)
        # TODO: Decide whether to use [X] or X? based on type of X
        return f"[{text}]" if " " in text else f"{text}?"

    def __repr__(self) -> str:
        return "Opt({!r})".format(self.node)

    def __iter__(self) -> Iterator[Item]:
        """Yield the wrapped node as the only child."""
        yield self.node

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always True; the wrapped node is not visited."""
        return True

    def initial_names(self) -> AbstractSet[str]:
        """Return the initial names of the wrapped node."""
        return self.node.initial_names()
347
348
class Repeat:
    """Shared base class for x* and x+.

    Attributes:
        node: the repeated node.
        memo: starts as None; not modified by the classes in this group.
    """

    def __init__(self, node: Plain):
        self.node = node
        self.memo: Optional[Tuple[Optional[str], str]] = None

    @abstractmethod
    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Subclasses must override; this base version always raises."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[Plain]:
        """Yield the repeated node as the only child."""
        yield self.node

    def initial_names(self) -> AbstractSet[str]:
        """Return the initial names of the repeated node."""
        return self.node.initial_names()


class Repeat0(Repeat):
    """Repetition rendered with a ``*`` suffix; always nullable."""

    def __str__(self) -> str:
        """Render as ``(X)*`` when X contains a space, else as ``X*``."""
        s = str(self.node)
        # TODO: Decide whether to use (X)* or X* based on type of X
        if " " in s:
            return f"({s})*"
        return f"{s}*"

    def __repr__(self) -> str:
        return f"Repeat0({self.node!r})"

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always True; the repeated node is not visited."""
        return True


class Repeat1(Repeat):
    """Repetition rendered with a ``+`` suffix; never nullable."""

    def __str__(self) -> str:
        """Render as ``(X)+`` when X contains a space, else as ``X+``."""
        s = str(self.node)
        # TODO: Decide whether to use (X)+ or X+ based on type of X
        if " " in s:
            return f"({s})+"
        return f"{s}+"

    def __repr__(self) -> str:
        return f"Repeat1({self.node!r})"

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always False; the repeated node is not visited."""
        return False


class Gather(Repeat):
    """Repetition with a separator, rendered as ``separator.node+``."""

    def __init__(self, separator: Plain, node: Plain):
        # Go through Repeat.__init__ so that ``memo`` exists on Gather
        # instances just as it does on every other Repeat.
        super().__init__(node)
        self.separator = separator

    def __str__(self) -> str:
        return f"{self.separator!s}.{self.node!s}+"

    def __repr__(self) -> str:
        return f"Gather({self.separator!r}, {self.node!r})"

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always False; neither separator nor node is visited."""
        return False
412
413
class Group:
    """A parenthesized right-hand side."""

    def __init__(self, rhs: Rhs):
        self.rhs = rhs

    def __str__(self) -> str:
        """Return the right-hand side wrapped in parentheses."""
        return "({})".format(self.rhs)

    def __repr__(self) -> str:
        return "Group({!r})".format(self.rhs)

    def __iter__(self) -> Iterator[Rhs]:
        """Yield the wrapped right-hand side as the only child."""
        yield self.rhs

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Return the nullability reported by the wrapped right-hand side."""
        inner = self.rhs
        return inner.nullable_visit(rules)

    def initial_names(self) -> AbstractSet[str]:
        """Return the initial names of the wrapped right-hand side."""
        inner = self.rhs
        return inner.initial_names()
432
433
class Cut:
    """Stateless grammar node printed as ``~``.

    Instances carry no data, so any two ``Cut`` objects compare equal and
    share one hash value.
    """

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "Cut()"

    def __str__(self) -> str:
        return "~"

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        """Yield nothing: a cut has no children to walk."""
        yield from ()

    def __eq__(self, other: object) -> bool:
        """Equal to every other ``Cut``; defer to ``other`` for anything else."""
        if not isinstance(other, Cut):
            return NotImplemented
        return True

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None and make
        # instances unhashable.  All instances are equal, so they must all
        # hash alike.
        return hash(Cut)

    def nullable_visit(self, rules: Dict[str, Rule]) -> bool:
        """Always True."""
        return True

    def initial_names(self) -> AbstractSet[str]:
        """Always an empty set."""
        return set()
458
459
# Type aliases built from the node classes defined above.

# A leaf (name or string) or a parenthesized group.
Plain = Union[Leaf, Group]
# Everything accepted as the ``item`` of a NamedItem.
Item = Union[Plain, Opt, Repeat, Lookahead, Rhs, Cut]
# The aliases below are not referenced in this part of the file; they are
# presumably convenience names for other modules -- verify against callers.
RuleName = Tuple[str, str]
# Same shape as the pairs accepted by Grammar.__init__'s ``metas`` argument.
MetaTuple = Tuple[str, Optional[str]]
MetaList = List[MetaTuple]
RuleList = List[Rule]
NamedItemList = List[NamedItem]
LookaheadOrCut = Union[Lookahead, Cut]
468