1"""Recognize image file formats based on their first few bytes."""
2
3from os import PathLike
4
# Public API for ``from ... import *``: only what() is exported; the
# per-format detectors and the ``tests`` registry are not.
__all__ = ["what"]
6
7#-------------------------#
8# Recognize image headers #
9#-------------------------#
10
def what(file, h=None):
    """Return the image type of *file*, guessed from its first bytes.

    *file* may be:

    * a path (``str`` or ``os.PathLike``) -- it is opened in binary mode,
      its first 32 bytes are read, and it is closed again;
    * an open binary file object -- its first 32 bytes are read and its
      position is restored with ``seek`` afterwards.

    If *h* (header bytes) is given, *file* is not touched at all and *h*
    is examined directly.

    Each detector in the module-level ``tests`` list is called in order as
    ``tf(h, f)``. Here ``f`` is the file object this function opened
    itself, or ``None`` when *file* was a caller-owned stream or *h* was
    supplied. The first truthy result (e.g. ``'png'``) is returned, or
    ``None`` when no detector matches.
    """
    if h is not None:
        return _first_match(h, None)
    if isinstance(file, (str, PathLike)):
        # We opened the file, so we close it. The context manager also
        # closes it when reading or a detector raises.
        with open(file, 'rb') as f:
            return _first_match(f.read(32), f)
    # Caller-owned stream: peek at the header, then rewind so the caller's
    # read position is left unchanged.
    location = file.tell()
    h = file.read(32)
    file.seek(location)
    return _first_match(h, None)


def _first_match(h, f):
    """Run each detector in ``tests`` on (h, f); return the first truthy result, else None."""
    for tf in tests:
        res = tf(h, f)
        if res:
            return res
    return None
29
30
31#---------------------------------#
32# Subroutines per image file type #
33#---------------------------------#
34
35tests = []
36
37def test_jpeg(h, f):
38    """JPEG data in JFIF or Exif format"""
39    if h[6:10] in (b'JFIF', b'Exif'):
40        return 'jpeg'
41
42tests.append(test_jpeg)
43
44def test_png(h, f):
45    if h.startswith(b'\211PNG\r\n\032\n'):
46        return 'png'
47
48tests.append(test_png)
49
50def test_gif(h, f):
51    """GIF ('87 and '89 variants)"""
52    if h[:6] in (b'GIF87a', b'GIF89a'):
53        return 'gif'
54
55tests.append(test_gif)
56
57def test_tiff(h, f):
58    """TIFF (can be in Motorola or Intel byte order)"""
59    if h[:2] in (b'MM', b'II'):
60        return 'tiff'
61
62tests.append(test_tiff)
63
64def test_rgb(h, f):
65    """SGI image library"""
66    if h.startswith(b'\001\332'):
67        return 'rgb'
68
69tests.append(test_rgb)
70
71def test_pbm(h, f):
72    """PBM (portable bitmap)"""
73    if len(h) >= 3 and \
74        h[0] == ord(b'P') and h[1] in b'14' and h[2] in b' \t\n\r':
75        return 'pbm'
76
77tests.append(test_pbm)
78
79def test_pgm(h, f):
80    """PGM (portable graymap)"""
81    if len(h) >= 3 and \
82        h[0] == ord(b'P') and h[1] in b'25' and h[2] in b' \t\n\r':
83        return 'pgm'
84
85tests.append(test_pgm)
86
87def test_ppm(h, f):
88    """PPM (portable pixmap)"""
89    if len(h) >= 3 and \
90        h[0] == ord(b'P') and h[1] in b'36' and h[2] in b' \t\n\r':
91        return 'ppm'
92
93tests.append(test_ppm)
94
95def test_rast(h, f):
96    """Sun raster file"""
97    if h.startswith(b'\x59\xA6\x6A\x95'):
98        return 'rast'
99
100tests.append(test_rast)
101
102def test_xbm(h, f):
103    """X bitmap (X10 or X11)"""
104    if h.startswith(b'#define '):
105        return 'xbm'
106
107tests.append(test_xbm)
108
109def test_bmp(h, f):
110    if h.startswith(b'BM'):
111        return 'bmp'
112
113tests.append(test_bmp)
114
115def test_webp(h, f):
116    if h.startswith(b'RIFF') and h[8:12] == b'WEBP':
117        return 'webp'
118
119tests.append(test_webp)
120
121def test_exr(h, f):
122    if h.startswith(b'\x76\x2f\x31\x01'):
123        return 'exr'
124
125tests.append(test_exr)
126
127#--------------------#
128# Small test program #
129#--------------------#
130
def test():
    """Command-line driver: report the detected image type of each path.

    Usage: ``[-r] [path ...]``. A leading ``-r`` is removed from
    ``sys.argv`` and passed on to testall() as *recursive*. With no paths,
    the current directory ``'.'`` is examined. Ctrl-C prints
    ``[Interrupted]`` to stderr and exits with status 1.
    """
    import sys
    recursive = 0
    args = sys.argv[1:]
    if args and args[0] == '-r':
        # Drop the flag from sys.argv itself so only paths remain.
        del sys.argv[1:2]
        recursive = 1
    paths = sys.argv[1:] or ['.']
    try:
        testall(paths, recursive, 1)
    except KeyboardInterrupt:
        sys.stderr.write('\n[Interrupted]\n')
        sys.exit(1)
145
def testall(list, recursive, toplevel):
    """Print ``<path>: <type>`` for every path in *list*.

    Directories are printed as ``<dir>/:``. Their contents are examined
    when *recursive* or *toplevel* is true; otherwise a ``(use -r)`` note
    is printed. Recursive calls pass ``toplevel=0``, so without *recursive*
    only one directory level is descended. For plain files the result of
    what() is printed, or ``*** not found ***`` if opening or reading the
    file raises OSError.
    """
    # NOTE: the parameter name ``list`` shadows the builtin; it is kept
    # because callers may pass it by keyword.
    import sys
    import os
    import glob
    for filename in list:
        if os.path.isdir(filename):
            print(filename + '/:', end=' ')
            if recursive or toplevel:
                print('recursing down:')
                # Escape the directory part so names containing glob
                # metacharacters ('*', '?', '[') are matched literally
                # instead of being interpreted as a pattern.
                names = glob.glob(os.path.join(glob.escape(filename), '*'))
                testall(names, recursive, 0)
            else:
                print('*** directory (use -r) ***')
        else:
            print(filename + ':', end=' ')
            # Show the name before what() does any (possibly slow) I/O.
            sys.stdout.flush()
            try:
                print(what(filename))
            except OSError:
                print('*** not found ***')
166
if __name__ == '__main__':
    # Run the command-line driver only when executed as a script, not on import.
    test()
169