1import antlr3
2import testbase
3import unittest
4
class t017parser(testbase.ANTLRTest):
    """Parser tests for the t017parser grammar: valid and malformed input.

    The parser used by the tests is a subclass built by ``parserClass`` that
    stores every emitted error message in ``reportedErrors``, so the tests
    can assert on how many errors were reported.
    """

    def setUp(self):
        # Delegates to the inherited helper; presumably this compiles the
        # grammar belonging to this test -- confirm against testbase.
        self.compileGrammar()

    def parserClass(self, base):
        """Return a subclass of ``base`` that records emitted error messages.

        NOTE(review): presumably invoked by ``testbase.ANTLRTest`` while
        building the parser returned by ``getParser`` -- confirm in testbase.
        """
        class TestParser(base):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Formatted error messages, in the order they were emitted.
                self.reportedErrors = []

            def emitErrorMessage(self, msg):
                # Collect the message on the instance so tests can inspect it.
                self.reportedErrors.append(msg)

        return TestParser

    def _parseProgram(self, source):
        """Run the ``program`` rule over ``source`` and return the parser.

        Builds the character stream, lexer, token stream and parser in the
        same order for every test; the returned parser exposes
        ``reportedErrors`` for assertions.
        """
        cStream = antlr3.StringStream(source)
        lexer = self.getLexer(cStream)
        tStream = antlr3.CommonTokenStream(lexer)
        parser = self.getParser(tStream)
        parser.program()
        return parser

    def testValid(self):
        # Well-formed input must not report any error.
        parser = self._parseProgram("int foo;")

        self.assertEqual(parser.reportedErrors, [])

    def testMalformedInput1(self):
        parser = self._parseProgram('int foo() { 1+2 }')

        # FIXME: currently strings with formatted errors are collected,
        # so error locations can't be checked yet -- only the count is.
        self.assertEqual(len(parser.reportedErrors), 1, parser.reportedErrors)

    def testMalformedInput2(self):
        parser = self._parseProgram('int foo() { 1+; 1+2 }')

        # FIXME: currently strings with formatted errors are collected,
        # so error locations can't be checked yet -- only the count is.
        self.assertEqual(len(parser.reportedErrors), 2, parser.reportedErrors)
55
56
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
59