1import imghdr
2import io
3import sys
4import unittest
5from test.test_support import findfile, TESTFN, unlink, run_unittest
6
# Fixture table: (file name inside the 'imghdrdata' data directory,
#                 type string imghdr.what() should report for that file).
TEST_FILES = (
    ('python.png', 'png'),
    ('python.gif', 'gif'),
    ('python.bmp', 'bmp'),
    ('python.ppm', 'ppm'),
    ('python.pgm', 'pgm'),
    ('python.pbm', 'pbm'),
    ('python.jpg', 'jpeg'),
    ('python.ras', 'rast'),
    ('python.sgi', 'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm', 'xbm'),
)
20
class UnseekableIO(io.FileIO):
    """A real file object whose position can be neither queried nor moved.

    Reading works as for a normal io.FileIO, but both seek() and tell()
    raise io.UnsupportedOperation. test_unseekable uses it to check how
    imghdr.what() behaves on a stream that cannot be rewound.
    """

    def seek(self, *args, **kwargs):
        # Refuse every form of repositioning, whatever arguments are given.
        raise io.UnsupportedOperation

    def tell(self):
        # Refuse to report the current offset as well.
        raise io.UnsupportedOperation
27
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() on paths, open streams and raw header bytes."""

    # The loops below attach a msg= naming the fixture being checked. With
    # longMessage enabled, that msg is appended to the standard failure text
    # rather than replacing it (the Python 2.7 default is False).
    longMessage = True

    @classmethod
    def setUpClass(cls):
        # One known-good PNG fixture, shared by several tests as both a
        # path (cls.testfile) and its raw contents (cls.testdata).
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Remove the scratch file that some tests write to TESTFN.
        unlink(TESTFN)

    def test_data(self):
        # Check every fixture through four input forms: a byte-string path,
        # a unicode path, an open binary stream, and the raw bytes.
        for name, expected in TEST_FILES:
            filename = findfile(name, subdir='imghdrdata')
            # There is no subTest() here, so name the fixture in each
            # message; otherwise the first failure hides which file broke.
            self.assertEqual(imghdr.what(filename), expected, name)
            ufilename = filename.decode(sys.getfilesystemencoding())
            self.assertEqual(imghdr.what(ufilename), expected, name)
            with open(filename, 'rb') as stream:
                self.assertEqual(imghdr.what(stream), expected, name)
            with open(filename, 'rb') as stream:
                data = stream.read()
            self.assertEqual(imghdr.what(None, data), expected, name)

    def test_register_test(self):
        # A callable appended to imghdr.tests is consulted by imghdr.what().
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        # Remove exactly the function added here, not whatever happens to
        # be last in the shared module-level list when cleanup runs.
        self.addCleanup(imghdr.tests.remove, test_jumbo)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # Write junk followed by the PNG, then point the stream at the PNG:
        # what() must detect from the current offset and leave it unchanged.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Missing argument entirely.
        with self.assertRaises(TypeError):
            imghdr.what()
        # None is not a path or a stream and no header bytes were given.
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        # h must be bytes, not an int.
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        # A raw file descriptor (int) is not accepted as a stream.
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # Truncated or near-miss signatures must not be recognised.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            self.assertIsNone(imghdr.what(None, header), repr(header))

    def test_missing_file(self):
        with self.assertRaises(IOError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Both a closed builtin file and a closed BytesIO must be rejected.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # The error from a stream that cannot report its position must
        # propagate out of what().
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # A stream opened for writing only cannot be read from.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(IOError):
                imghdr.what(stream)
115
def test_main():
    """Run every imghdr test case through the regression-test runner."""
    test_classes = (TestImghdr,)
    run_unittest(*test_classes)

if __name__ == '__main__':
    # Allow running this module directly as a script.
    test_main()
121