import io
import textwrap
import unittest
from tokenize import TokenInfo, NAME, NEWLINE, NUMBER, OP
from typing import Any, Dict, List, Type

from test import test_tools

test_tools.skip_if_missing('peg_generator')
with test_tools.imports_under_tool('peg_generator'):
    from pegen.grammar_parser import GeneratedParser as GrammarParser
    from pegen.testutil import (
        parse_string,
        generate_parser,
        make_parser
    )
    from pegen.grammar import GrammarVisitor, GrammarError, Grammar
    from pegen.grammar_visualizer import ASTGrammarPrinter
    from pegen.parser import Parser
    from pegen.python_generator import PythonParserGenerator
21
22
class TestPegen(unittest.TestCase):
    """Tests for the pegen grammar parser and the Python parser generator.

    Each test either inspects the ``Grammar`` produced by parsing a grammar
    description (its ``str``/``repr`` and the computed rule properties
    ``nullable`` / ``left_recursive``), or builds a parser from the grammar
    and compares the nested lists of ``TokenInfo`` it returns for an input.
    """

    def test_parse_grammar(self) -> None:
        grammar_source = """
        start: sum NEWLINE
        sum: t1=term '+' t2=term { action } | term
        term: NUMBER
        """
        expected = """
        start: sum NEWLINE
        sum: term '+' term | term
        term: NUMBER
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        rules = grammar.rules
        self.assertEqual(str(grammar), textwrap.dedent(expected).strip())
        # Check the str() and repr() of a few rules; AST nodes don't support ==.
        self.assertEqual(str(rules["start"]), "start: sum NEWLINE")
        self.assertEqual(str(rules["sum"]), "sum: term '+' term | term")
        expected_repr = "Rule('term', None, Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))"
        self.assertEqual(repr(rules["term"]), expected_repr)

    def test_long_rule_str(self) -> None:
        grammar_source = """
        start: zero | one | one zero | one one | one zero zero | one zero one | one one zero | one one one
        """
        expected = """
        start:
            | zero
            | one
            | one zero
            | one one
            | one zero zero
            | one zero one
            | one one zero
            | one one one
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        self.assertEqual(str(grammar.rules["start"]), textwrap.dedent(expected).strip())

    def test_typed_rules(self) -> None:
        grammar = """
        start[int]: sum NEWLINE
        sum[int]: t1=term '+' t2=term { action } | term
        term[int]: NUMBER
        """
        rules = parse_string(grammar, GrammarParser).rules
        # Check the str() and repr() of a few rules; AST nodes don't support ==.
        self.assertEqual(str(rules["start"]), "start: sum NEWLINE")
        self.assertEqual(str(rules["sum"]), "sum: term '+' term | term")
        self.assertEqual(
            repr(rules["term"]),
            "Rule('term', 'int', Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))"
        )

    def test_gather(self) -> None:
        grammar = """
        start: ','.thing+ NEWLINE
        thing: NUMBER
        """
        rules = parse_string(grammar, GrammarParser).rules
        self.assertEqual(str(rules["start"]), "start: ','.thing+ NEWLINE")
        self.assertTrue(repr(rules["start"]).startswith(
            "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
        ))
        self.assertEqual(str(rules["thing"]), "thing: NUMBER")
        parser_class = make_parser(grammar)
        # Use assertEqual rather than a bare assert: bare asserts are
        # stripped under ``python -O`` and give poorer failure output.
        node = parse_string("42\n", parser_class)
        self.assertEqual(node, [
            [[TokenInfo(NUMBER, string="42", start=(1, 0), end=(1, 2), line="42\n")]],
            TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="42\n"),
        ])
        node = parse_string("1, 2\n", parser_class)
        self.assertEqual(node, [
            [
                [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2\n")],
                [TokenInfo(NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2\n")],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 4), end=(1, 5), line="1, 2\n"),
        ])

    def test_expr_grammar(self) -> None:
        grammar = """
        start: sum NEWLINE
        sum: term '+' term | term
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("42\n", parser_class)
        self.assertEqual(node, [
            [[TokenInfo(NUMBER, string="42", start=(1, 0), end=(1, 2), line="42\n")]],
            TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="42\n"),
        ])

    def test_optional_operator(self) -> None:
        grammar = """
        start: sum NEWLINE
        sum: term ('+' term)?
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1+2\n", parser_class)
        self.assertEqual(node, [
            [
                [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1+2\n")],
                [
                    TokenInfo(OP, string="+", start=(1, 1), end=(1, 2), line="1+2\n"),
                    [TokenInfo(NUMBER, string="2", start=(1, 2), end=(1, 3), line="1+2\n")],
                ],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 3), end=(1, 4), line="1+2\n"),
        ])
        node = parse_string("1\n", parser_class)
        self.assertEqual(node, [
            [[TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n")], None],
            TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
        ])

    def test_optional_literal(self) -> None:
        grammar = """
        start: sum NEWLINE
        sum: term '+' ?
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1+\n", parser_class)
        self.assertEqual(node, [
            [
                [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1+\n")],
                TokenInfo(OP, string="+", start=(1, 1), end=(1, 2), line="1+\n"),
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="1+\n"),
        ])
        node = parse_string("1\n", parser_class)
        self.assertEqual(node, [
            [[TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n")], None],
            TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
        ])

    def test_alt_optional_operator(self) -> None:
        grammar = """
        start: sum NEWLINE
        sum: term ['+' term]
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2\n", parser_class)
        self.assertEqual(node, [
            [
                [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2\n")],
                [
                    TokenInfo(OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2\n"),
                    [TokenInfo(NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2\n")],
                ],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 + 2\n"),
        ])
        node = parse_string("1\n", parser_class)
        self.assertEqual(node, [
            [[TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n")], None],
            TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
        ])

    def test_repeat_0_simple(self) -> None:
        grammar = """
        start: thing thing* NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 2 3\n", parser_class)
        self.assertEqual(node, [
            [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n")],
            [
                [[TokenInfo(NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n")]],
                [[TokenInfo(NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n")]],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n"),
        ])
        node = parse_string("1\n", parser_class)
        self.assertEqual(node, [
            [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n")],
            [],
            TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
        ])

    def test_repeat_0_complex(self) -> None:
        grammar = """
        start: term ('+' term)* NEWLINE
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(node, [
            [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n")],
            [
                [
                    [
                        TokenInfo(OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"),
                        [TokenInfo(NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2 + 3\n")],
                    ]
                ],
                [
                    [
                        TokenInfo(OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"),
                        [TokenInfo(NUMBER, string="3", start=(1, 8), end=(1, 9), line="1 + 2 + 3\n")],
                    ]
                ],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"),
        ])

    def test_repeat_1_simple(self) -> None:
        grammar = """
        start: thing thing+ NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 2 3\n", parser_class)
        self.assertEqual(node, [
            [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n")],
            [
                [[TokenInfo(NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n")]],
                [[TokenInfo(NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n")]],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n"),
        ])
        with self.assertRaises(SyntaxError):
            parse_string("1\n", parser_class)

    def test_repeat_1_complex(self) -> None:
        grammar = """
        start: term ('+' term)+ NEWLINE
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(node, [
            [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n")],
            [
                [
                    [
                        TokenInfo(OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"),
                        [TokenInfo(NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2 + 3\n")],
                    ]
                ],
                [
                    [
                        TokenInfo(OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"),
                        [TokenInfo(NUMBER, string="3", start=(1, 8), end=(1, 9), line="1 + 2 + 3\n")],
                    ]
                ],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"),
        ])
        with self.assertRaises(SyntaxError):
            parse_string("1\n", parser_class)

    def test_repeat_with_sep_simple(self) -> None:
        grammar = """
        start: ','.thing+ NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1, 2, 3\n", parser_class)
        self.assertEqual(node, [
            [
                [TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2, 3\n")],
                [TokenInfo(NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2, 3\n")],
                [TokenInfo(NUMBER, string="3", start=(1, 6), end=(1, 7), line="1, 2, 3\n")],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 7), end=(1, 8), line="1, 2, 3\n"),
        ])

    def test_left_recursive(self) -> None:
        grammar_source = """
        start: expr NEWLINE
        expr: ('-' term | expr '+' term | term)
        term: NUMBER
        foo: NAME+
        bar: NAME*
        baz: NAME?
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        parser_class = generate_parser(grammar)
        rules = grammar.rules
        self.assertFalse(rules["start"].left_recursive)
        self.assertTrue(rules["expr"].left_recursive)
        self.assertFalse(rules["term"].left_recursive)
        self.assertFalse(rules["foo"].left_recursive)
        self.assertFalse(rules["bar"].left_recursive)
        self.assertFalse(rules["baz"].left_recursive)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(node, [
            [
                [
                    [[TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n")]],
                    TokenInfo(OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"),
                    [TokenInfo(NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2 + 3\n")],
                ],
                TokenInfo(OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"),
                [TokenInfo(NUMBER, string="3", start=(1, 8), end=(1, 9), line="1 + 2 + 3\n")],
            ],
            TokenInfo(NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"),
        ])

    def test_python_expr(self) -> None:
        grammar = """
        start: expr NEWLINE? $ { ast.Expression(expr, lineno=1, col_offset=0) }
        expr: ( expr '+' term { ast.BinOp(expr, ast.Add(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) }
            | expr '-' term { ast.BinOp(expr, ast.Sub(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) }
            | term { term }
            )
        term: ( l=term '*' r=factor { ast.BinOp(l, ast.Mult(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) }
            | l=term '/' r=factor { ast.BinOp(l, ast.Div(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) }
            | factor { factor }
            )
        factor: ( '(' expr ')' { expr }
                | atom { atom }
                )
        atom: ( n=NAME { ast.Name(id=n.string, ctx=ast.Load(), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) }
            | n=NUMBER { ast.Constant(value=ast.literal_eval(n.string), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) }
            )
        """
        parser_class = make_parser(grammar)
        node = parse_string("(1 + 2*3 + 5)/(6 - 2)\n", parser_class)
        # The grammar actions build an ast.Expression, so the parse result can
        # be compiled and evaluated directly.
        code = compile(node, "", "eval")
        val = eval(code)
        self.assertEqual(val, 3.0)

    def test_nullable(self) -> None:
        grammar_source = """
        start: sign NUMBER
        sign: ['-' | '+']
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        out = io.StringIO()
        # Constructed only for its side effect: the assertions below rely on
        # the generator having populated ``nullable`` on the rules
        # (presumably done in its constructor).
        PythonParserGenerator(grammar, out)
        rules = grammar.rules
        self.assertFalse(rules["start"].nullable)  # Not None!
        self.assertTrue(rules["sign"].nullable)

    def test_advanced_left_recursive(self) -> None:
        grammar_source = """
        start: NUMBER | sign start
        sign: ['-']
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        out = io.StringIO()
        # Constructed only for its side effect of populating ``nullable`` and
        # ``left_recursive`` on the rules (presumably in its constructor).
        PythonParserGenerator(grammar, out)
        rules = grammar.rules
        self.assertFalse(rules["start"].nullable)  # Not None!
        self.assertTrue(rules["sign"].nullable)
        # 'start' is left-recursive because 'sign' can match the empty string.
        self.assertTrue(rules["start"].left_recursive)
        self.assertFalse(rules["sign"].left_recursive)

    def test_mutually_left_recursive(self) -> None:
        grammar_source = """
        start: foo 'E'
        foo: bar 'A' | 'B'
        bar: foo 'C' | 'D'
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        out = io.StringIO()
        genr = PythonParserGenerator(grammar, out)
        rules = grammar.rules
        self.assertFalse(rules["start"].left_recursive)
        self.assertTrue(rules["foo"].left_recursive)
        self.assertTrue(rules["bar"].left_recursive)
        genr.generate("<string>")
        ns: Dict[str, Any] = {}
        exec(out.getvalue(), ns)
        parser_class: Type[Parser] = ns["GeneratedParser"]
        node = parse_string("D A C A E", parser_class)
        self.assertEqual(node, [
            [
                [
                    [
                        [TokenInfo(type=NAME, string="D", start=(1, 0), end=(1, 1), line="D A C A E")],
                        TokenInfo(type=NAME, string="A", start=(1, 2), end=(1, 3), line="D A C A E"),
                    ],
                    TokenInfo(type=NAME, string="C", start=(1, 4), end=(1, 5), line="D A C A E"),
                ],
                TokenInfo(type=NAME, string="A", start=(1, 6), end=(1, 7), line="D A C A E"),
            ],
            TokenInfo(type=NAME, string="E", start=(1, 8), end=(1, 9), line="D A C A E"),
        ])
        node = parse_string("B C A E", parser_class)
        self.assertIsNotNone(node)
        self.assertEqual(node, [
            [
                [
                    [TokenInfo(type=NAME, string="B", start=(1, 0), end=(1, 1), line="B C A E")],
                    TokenInfo(type=NAME, string="C", start=(1, 2), end=(1, 3), line="B C A E"),
                ],
                TokenInfo(type=NAME, string="A", start=(1, 4), end=(1, 5), line="B C A E"),
            ],
            TokenInfo(type=NAME, string="E", start=(1, 6), end=(1, 7), line="B C A E"),
        ])

    def test_nasty_mutually_left_recursive(self) -> None:
        # This grammar does not recognize 'x - + =', much to my chagrin.
        # But that's the way PEG works.
        # [Breathlessly]
        # The problem is that the toplevel target call
        # recurses into maybe, which recognizes 'x - +',
        # and then the toplevel target looks for another '+',
        # which fails, so it retreats to NAME,
        # which succeeds, so we end up just recognizing 'x',
        # and then start fails because there's no '=' after that.
        grammar_source = """
        start: target '='
        target: maybe '+' | NAME
        maybe: maybe '-' | target
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        out = io.StringIO()
        genr = PythonParserGenerator(grammar, out)
        genr.generate("<string>")
        ns: Dict[str, Any] = {}
        exec(out.getvalue(), ns)
        parser_class = ns["GeneratedParser"]
        with self.assertRaises(SyntaxError):
            parse_string("x - + =", parser_class)

    def test_lookahead(self) -> None:
        grammar = """
        start: (expr_stmt | assign_stmt) &'.'
        expr_stmt: !(target '=') expr
        assign_stmt: target '=' expr
        expr: term ('+' term)*
        target: NAME
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("foo = 12 + 12 .", parser_class)
        self.assertEqual(node, [
            [
                [
                    [TokenInfo(NAME, string="foo", start=(1, 0), end=(1, 3), line="foo = 12 + 12 .")],
                    TokenInfo(OP, string="=", start=(1, 4), end=(1, 5), line="foo = 12 + 12 ."),
                    [
                        [
                            TokenInfo(
                                NUMBER, string="12", start=(1, 6), end=(1, 8), line="foo = 12 + 12 ."
                            )
                        ],
                        [
                            [
                                [
                                    TokenInfo(
                                        OP,
                                        string="+",
                                        start=(1, 9),
                                        end=(1, 10),
                                        line="foo = 12 + 12 .",
                                    ),
                                    [
                                        TokenInfo(
                                            NUMBER,
                                            string="12",
                                            start=(1, 11),
                                            end=(1, 13),
                                            line="foo = 12 + 12 .",
                                        )
                                    ],
                                ]
                            ]
                        ],
                    ],
                ]
            ]
        ])

    def test_named_lookahead_error(self) -> None:
        grammar = """
        start: foo=!'x' NAME
        """
        with self.assertRaises(SyntaxError):
            make_parser(grammar)

    def test_start_leader(self) -> None:
        grammar = """
        start: attr | NAME
        attr: start '.' NAME
        """
        # Would assert False without a special case in compute_left_recursives().
        make_parser(grammar)

    def test_opt_sequence(self) -> None:
        grammar = """
        start: [NAME*]
        """
        # This case was failing because of a double trailing comma at the end
        # of a line in the generated source. See bpo-41044
        make_parser(grammar)

    def test_left_recursion_too_complex(self) -> None:
        grammar = """
        start: foo
        foo: bar '+' | baz '+' | '+'
        bar: baz '-' | foo '-' | '-'
        baz: foo '*' | bar '*' | '*'
        """
        with self.assertRaises(ValueError) as errinfo:
            make_parser(grammar)
        # Checked after the ``with`` block: any statement inside it that
        # follows the raising call would never execute.
        self.assertIn("no leader", str(errinfo.exception))

    def test_cut(self) -> None:
        grammar = """
        start: '(' ~ expr ')'
        expr: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("(1)", parser_class)
        self.assertEqual(node, [
            TokenInfo(OP, string="(", start=(1, 0), end=(1, 1), line="(1)"),
            [TokenInfo(NUMBER, string="1", start=(1, 1), end=(1, 2), line="(1)")],
            TokenInfo(OP, string=")", start=(1, 2), end=(1, 3), line="(1)"),
        ])

    def test_dangling_reference(self) -> None:
        grammar = """
        start: foo ENDMARKER
        foo: bar NAME
        """
        with self.assertRaises(GrammarError):
            make_parser(grammar)

    def test_bad_token_reference(self) -> None:
        grammar = """
        start: foo
        foo: NAMEE
        """
        with self.assertRaises(GrammarError):
            make_parser(grammar)

    def test_missing_start(self) -> None:
        grammar = """
        foo: NAME
        """
        with self.assertRaises(GrammarError):
            make_parser(grammar)

    def test_invalid_rule_name(self) -> None:
        grammar = """
        start: _a b
        _a: 'a'
        b: 'b'
        """
        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_a'"):
            make_parser(grammar)

    def test_invalid_variable_name(self) -> None:
        grammar = """
        start: a b
        a: _x='a'
        b: 'b'
        """
        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"):
            make_parser(grammar)

    def test_invalid_variable_name_in_temporal_rule(self) -> None:
        grammar = """
        start: a b
        a: (_x='a' | 'b') | 'c'
        b: 'b'
        """
        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"):
            make_parser(grammar)
591
592
class TestGrammarVisitor(unittest.TestCase):
    """Check that GrammarVisitor reaches every node of a parsed grammar.

    Inherits from unittest.TestCase so the tests are actually collected.
    It also makes ``self.assertEqual`` available to them.
    """

    class Visitor(GrammarVisitor):
        """Visitor that only counts how many nodes ``visit`` is called on."""

        def __init__(self) -> None:
            self.n_nodes = 0

        def visit(self, node: Any, *args: Any, **kwargs: Any) -> None:
            self.n_nodes += 1
            # Delegate to the base dispatch, which recurses into children
            # through this overridden ``visit``, so every node is counted.
            super().visit(node, *args, **kwargs)

    def test_parse_trivial_grammar(self) -> None:
        grammar = """
        start: 'a'
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/StringLeaf -> 6
        self.assertEqual(visitor.n_nodes, 6)

    def test_parse_or_grammar(self) -> None:
        grammar = """
        start: rule
        rule: 'a' | 'b'
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/NameLeaf   -> 6
        #         Rule/Rhs/                         -> 2
        #                  Alt/NamedItem/StringLeaf -> 3
        #                  Alt/NamedItem/StringLeaf -> 3

        self.assertEqual(visitor.n_nodes, 14)

    def test_parse_repeat1_grammar(self) -> None:
        grammar = """
        start: 'a'+
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/Repeat1/StringLeaf -> 7
        self.assertEqual(visitor.n_nodes, 7)

    def test_parse_repeat0_grammar(self) -> None:
        grammar = """
        start: 'a'*
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/Repeat0/StringLeaf -> 7

        self.assertEqual(visitor.n_nodes, 7)

    def test_parse_optional_grammar(self) -> None:
        grammar = """
        start: 'a' ['b']
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/StringLeaf                       -> 6
        #                      NamedItem/Opt/Rhs/Alt/NamedItem/StringLeaf -> 6

        self.assertEqual(visitor.n_nodes, 12)
668
669
class TestGrammarVisualizer(unittest.TestCase):
    """Golden-output tests for the tree rendering of ASTGrammarPrinter."""

    def _render(self, grammar_source: str) -> str:
        """Parse *grammar_source* and return the printed AST lines joined by newlines."""
        parsed = parse_string(grammar_source, GrammarParser)
        collected: List[str] = []
        ASTGrammarPrinter().print_grammar_ast(parsed, printer=collected.append)
        return "\n".join(collected)

    def test_simple_rule(self) -> None:
        source = """
        start: 'a' 'b'
        """
        rendered = self._render(source)
        golden = textwrap.dedent(
            """\
        └──Rule
           └──Rhs
              └──Alt
                 ├──NamedItem
                 │  └──StringLeaf("'a'")
                 └──NamedItem
                    └──StringLeaf("'b'")
        """
        )
        self.assertEqual(rendered, golden)

    def test_multiple_rules(self) -> None:
        source = """
        start: a b
        a: 'a'
        b: 'b'
        """
        rendered = self._render(source)
        golden = textwrap.dedent(
            """\
        └──Rule
           └──Rhs
              └──Alt
                 ├──NamedItem
                 │  └──NameLeaf('a')
                 └──NamedItem
                    └──NameLeaf('b')

        └──Rule
           └──Rhs
              └──Alt
                 └──NamedItem
                    └──StringLeaf("'a'")

        └──Rule
           └──Rhs
              └──Alt
                 └──NamedItem
                    └──StringLeaf("'b'")
                        """
        )
        self.assertEqual(rendered, golden)

    def test_deep_nested_rule(self) -> None:
        source = """
        start: 'a' ['b'['c'['d']]]
        """
        rendered = self._render(source)
        golden = textwrap.dedent(
            """\
        └──Rule
           └──Rhs
              └──Alt
                 ├──NamedItem
                 │  └──StringLeaf("'a'")
                 └──NamedItem
                    └──Opt
                       └──Rhs
                          └──Alt
                             ├──NamedItem
                             │  └──StringLeaf("'b'")
                             └──NamedItem
                                └──Opt
                                   └──Rhs
                                      └──Alt
                                         ├──NamedItem
                                         │  └──StringLeaf("'c'")
                                         └──NamedItem
                                            └──Opt
                                               └──Rhs
                                                  └──Alt
                                                     └──NamedItem
                                                        └──StringLeaf("'d'")
                                """
        )
        self.assertEqual(rendered, golden)
775