1import antlr3
2import testbase
3import unittest
4
5
class _TraceRecorder:
    """Mixin that records rule entry/exit notifications in ``self.traces``.

    With ``-trace``, the generated recognizers call ``traceIn``/``traceOut``
    as rules are entered and left.  Each call appends ``'>' + ruleName`` or
    ``'<' + ruleName`` so a test can compare the exact sequence.  Mix it in
    before the generated base class so these methods take precedence.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.traces = []

    def traceIn(self, ruleName, ruleIndex):
        self.traces.append('>' + ruleName)

    def traceOut(self, ruleName, ruleIndex):
        self.traces.append('<' + ruleName)


class T(testbase.ANTLRTest):
    """Checks the trace hooks and rule invocation stack of a grammar
    compiled with ``-trace``."""

    def setUp(self):
        self.compileGrammar(options='-trace')

    def lexerClass(self, base):
        """Return a subclass of the generated lexer *base* that records traces."""
        class TLexer(_TraceRecorder, base):
            def recover(self, *args):
                """Do not recover: re-raise the recognition error.

                NOTE(review): the antlr3 Python runtime's ``Lexer`` calls
                ``self.recover(re)`` with a single argument, in ``match``
                even outside any ``except`` block.  The previous
                ``(input, re)`` signature and bare ``raise`` therefore broke
                there.  Accepting ``*args`` and raising the last argument
                works for both the one- and two-argument forms.
                """
                raise args[-1]

        return TLexer

    def parserClass(self, base):
        """Return a subclass of the generated parser *base* that records traces."""
        class TParser(_TraceRecorder, base):
            def recover(self, input, re):
                # No error recovery: let the test fail loudly.  Raise ``re``
                # explicitly; a bare ``raise`` only works while an exception
                # is being handled.
                raise re

            def getRuleInvocationStack(self):
                # Pass the generated parser's module rather than this test
                # module (where TParser is defined); presumably the helper
                # uses it to pick which stack frames count as rules.
                return self._getRuleInvocationStack(base.__module__)

        return TParser

    def _parse(self, text):
        """Lex *text*, parse it with start rule ``a``; return ``(lexer, parser)``."""
        cStream = antlr3.StringStream(text)
        lexer = self.getLexer(cStream)
        tStream = antlr3.CommonTokenStream(lexer)
        parser = self.getParser(tStream)
        parser.a()
        return lexer, parser

    def testTrace(self):
        lexer, parser = self._parse('< 1 + 2 + 3 >')

        self.assertEqual(
            lexer.traces,
            [ '>T__7', '<T__7', '>WS', '<WS', '>INT', '<INT', '>WS', '<WS',
              '>T__6', '<T__6', '>WS', '<WS', '>INT', '<INT', '>WS', '<WS',
              '>T__6', '<T__6', '>WS', '<WS', '>INT', '<INT', '>WS', '<WS',
              '>T__8', '<T__8']
            )

        self.assertEqual(
            parser.traces,
            [ '>a', '>synpred1_t044trace_fragment', '<synpred1_t044trace_fragment', '>b', '>c',
              '<c', '>c', '<c', '>c', '<c', '<b', '<a' ]
            )

    def testInvokationStack(self):
        _, parser = self._parse('< 1 + 2 + 3 >')

        # ``_stack`` is not assigned in this file; presumably an action in the
        # grammar stores getRuleInvocationStack() there -- confirm in the .g.
        self.assertEqual(parser._stack, ['a', 'b', 'c'])
90
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
93