1from test import support
2from test.support import bigmemtest, _4G
3
4import unittest
5from io import BytesIO, DEFAULT_BUFFER_SIZE
6import os
7import pickle
8import glob
9import pathlib
10import random
11import shutil
12import subprocess
13import sys
14from test.support import unlink
15import _compression
16
17try:
18    import threading
19except ImportError:
20    threading = None
21
22# Skip tests if the bz2 module doesn't exist.
23bz2 = support.import_module('bz2')
24from bz2 import BZ2File, BZ2Compressor, BZ2Decompressor
25
# Tri-state cache: None until the first probe, then True/False depending on
# whether a `bunzip2` executable was found on PATH.
has_cmdline_bunzip2 = None

def ext_decompress(data):
    # Decompress `data` with the external `bunzip2` tool when one is
    # installed, so that the compressor tests are checked against an
    # independent implementation; otherwise fall back to bz2.decompress().
    global has_cmdline_bunzip2
    if has_cmdline_bunzip2 is None:
        # Probe PATH only once per test run.
        has_cmdline_bunzip2 = bool(shutil.which('bunzip2'))
    if not has_cmdline_bunzip2:
        return bz2.decompress(data)
    return subprocess.check_output(['bunzip2'], input=data)
36
class BaseTest(unittest.TestCase):
    "Base for other testcases."

    # Reference plaintext, kept line by line so that readline()/readlines()
    # results can be compared directly.
    TEXT_LINES = [
        b'root:x:0:0:root:/root:/bin/bash\n',
        b'bin:x:1:1:bin:/bin:\n',
        b'daemon:x:2:2:daemon:/sbin:\n',
        b'adm:x:3:4:adm:/var/adm:\n',
        b'lp:x:4:7:lp:/var/spool/lpd:\n',
        b'sync:x:5:0:sync:/sbin:/bin/sync\n',
        b'shutdown:x:6:0:shutdown:/sbin:/sbin/shutdown\n',
        b'halt:x:7:0:halt:/sbin:/sbin/halt\n',
        b'mail:x:8:12:mail:/var/spool/mail:\n',
        b'news:x:9:13:news:/var/spool/news:\n',
        b'uucp:x:10:14:uucp:/var/spool/uucp:\n',
        b'operator:x:11:0:operator:/root:\n',
        b'games:x:12:100:games:/usr/games:\n',
        b'gopher:x:13:30:gopher:/usr/lib/gopher-data:\n',
        b'ftp:x:14:50:FTP User:/var/ftp:/bin/bash\n',
        b'nobody:x:65534:65534:Nobody:/home:\n',
        b'postfix:x:100:101:postfix:/var/spool/postfix:\n',
        b'niemeyer:x:500:500::/home/niemeyer:/bin/bash\n',
        b'postgres:x:101:102:PostgreSQL Server:/var/lib/pgsql:/bin/bash\n',
        b'mysql:x:102:103:MySQL server:/var/lib/mysql:/bin/bash\n',
        b'www:x:103:104::/var/www:/bin/false\n',
        ]
    TEXT = b''.join(TEXT_LINES)
    # Pre-compressed form of TEXT (a single bzip2 stream).
    DATA = b'BZh91AY&SY.\xc8N\x18\x00\x01>_\x80\x00\x10@\x02\xff\xf0\x01\x07n\x00?\xe7\xff\xe00\x01\x99\xaa\x00\xc0\x03F\x86\x8c#&\x83F\x9a\x03\x06\xa6\xd0\xa6\x93M\x0fQ\xa7\xa8\x06\x804hh\x12$\x11\xa4i4\xf14S\xd2<Q\xb5\x0fH\xd3\xd4\xdd\xd5\x87\xbb\xf8\x94\r\x8f\xafI\x12\xe1\xc9\xf8/E\x00pu\x89\x12]\xc9\xbbDL\nQ\x0e\t1\x12\xdf\xa0\xc0\x97\xac2O9\x89\x13\x94\x0e\x1c7\x0ed\x95I\x0c\xaaJ\xa4\x18L\x10\x05#\x9c\xaf\xba\xbc/\x97\x8a#C\xc8\xe1\x8cW\xf9\xe2\xd0\xd6M\xa7\x8bXa<e\x84t\xcbL\xb3\xa7\xd9\xcd\xd1\xcb\x84.\xaf\xb3\xab\xab\xad`n}\xa0lh\tE,\x8eZ\x15\x17VH>\x88\xe5\xcd9gd6\x0b\n\xe9\x9b\xd5\x8a\x99\xf7\x08.K\x8ev\xfb\xf7xw\xbb\xdf\xa1\x92\xf1\xdd|/";\xa2\xba\x9f\xd5\xb1#A\xb6\xf6\xb3o\xc9\xc5y\\\xebO\xe7\x85\x9a\xbc\xb6f8\x952\xd5\xd7"%\x89>V,\xf7\xa6z\xe2\x9f\xa3\xdf\x11\x11"\xd6E)I\xa9\x13^\xca\xf3r\xd0\x03U\x922\xf26\xec\xb6\xed\x8b\xc3U\x13\x9d\xc5\x170\xa4\xfa^\x92\xacDF\x8a\x97\xd6\x19\xfe\xdd\xb8\xbd\x1a\x9a\x19\xa3\x80ankR\x8b\xe5\xd83]\xa9\xc6\x08\x82f\xf6\xb9"6l$\xb8j@\xc0\x8a\xb0l1..\xbak\x83ls\x15\xbc\xf4\xc1\x13\xbe\xf8E\xb8\x9d\r\xa8\x9dk\x84\xd3n\xfa\xacQ\x07\xb1%y\xaav\xb4\x08\xe0z\x1b\x16\xf5\x04\xe9\xcc\xb9\x08z\x1en7.G\xfc]\xc9\x14\xe1B@\xbb!8`'
    # Compressed form of the empty byte string.
    EMPTY_DATA = b'BZh9\x17rE8P\x90\x00\x00\x00\x00'
    BAD_DATA = b'this is not a valid bzip2 file'

    # Some tests need more than one block of uncompressed data. Since one block
    # is at least 100 kB, we gather some data dynamically and compress it.
    # Note that this assumes that compression works correctly, so we cannot
    # simply use the bigger test data for all tests.
    test_size = 0
    BIG_TEXT = bytearray(128*1024)
    # The directory is escaped so that glob metacharacters in the checkout
    # path are taken literally, and the matches are sorted so the gathered
    # data does not depend on directory enumeration order.
    for fname in sorted(glob.glob(os.path.join(
            glob.escape(os.path.dirname(__file__)), '*.py'))):
        with open(fname, 'rb') as fh:
            test_size += fh.readinto(memoryview(BIG_TEXT)[test_size:])
        # readinto() can never write past the end of the buffer, so
        # test_size tops out at exactly len(BIG_TEXT); stop as soon as the
        # buffer is full instead of opening every remaining file.
        if test_size >= len(BIG_TEXT):
            break
    BIG_DATA = bz2.compress(BIG_TEXT, compresslevel=1)

    def setUp(self):
        # Every test works on the same scratch file name.
        self.filename = support.TESTFN

    def tearDown(self):
        # Remove the scratch file if the test left one behind.
        if os.path.isfile(self.filename):
            os.unlink(self.filename)
87
88
class BZ2FileTest(BaseTest):
    "Test the BZ2File class."

    def createTempFile(self, streams=1, suffix=b""):
        # Fill the scratch file with `streams` back-to-back copies of the
        # reference stream, followed by the raw `suffix` bytes.
        with open(self.filename, "wb") as f:
            f.write(self.DATA * streams + suffix)

    def testBadArgs(self):
        # A float is neither a path nor a file object.
        self.assertRaises(TypeError, BZ2File, 123.456)
        # Unknown or contradictory mode strings.
        for bad_mode in ("z", "rx", "rbt"):
            with self.subTest(mode=bad_mode):
                self.assertRaises(ValueError, BZ2File, os.devnull, bad_mode)
        # Compression levels just outside the accepted range.
        for bad_level in (0, 10):
            with self.subTest(compresslevel=bad_level):
                self.assertRaises(ValueError, BZ2File, os.devnull,
                                  compresslevel=bad_level)

    def testRead(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # A float size is rejected.
            with self.assertRaises(TypeError):
                bz2f.read(float())
            self.assertEqual(bz2f.read(), self.TEXT)

    def testReadBadFile(self):
        # The file holds no bzip2 stream at all, only garbage.
        self.createTempFile(streams=0, suffix=self.BAD_DATA)
        with BZ2File(self.filename) as bz2f:
            with self.assertRaises(OSError):
                bz2f.read()

    def testReadMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            with self.assertRaises(TypeError):
                bz2f.read(float())
            self.assertEqual(bz2f.read(), self.TEXT * 5)

    def testReadMonkeyMultiStream(self):
        # Test BZ2File.read() on a multi-stream archive where a stream
        # boundary coincides with the end of the raw read buffer.
        saved_size = _compression.BUFFER_SIZE
        _compression.BUFFER_SIZE = len(self.DATA)
        try:
            self.createTempFile(streams=5)
            with BZ2File(self.filename) as bz2f:
                with self.assertRaises(TypeError):
                    bz2f.read(float())
                self.assertEqual(bz2f.read(), self.TEXT * 5)
        finally:
            # Always restore the module-level constant for other tests.
            _compression.BUFFER_SIZE = saved_size

    def testReadTrailingJunk(self):
        # Garbage after a complete stream is ignored.
        self.createTempFile(suffix=self.BAD_DATA)
        with BZ2File(self.filename) as bz2f:
            contents = bz2f.read()
        self.assertEqual(contents, self.TEXT)

    def testReadMultiStreamTrailingJunk(self):
        self.createTempFile(streams=5, suffix=self.BAD_DATA)
        with BZ2File(self.filename) as bz2f:
            contents = bz2f.read()
        self.assertEqual(contents, self.TEXT * 5)

    def testRead0(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            with self.assertRaises(TypeError):
                bz2f.read(float())
            self.assertEqual(bz2f.read(0), b"")

    def testReadChunk10(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # Pull 10 bytes at a time until read() signals EOF with b''.
            chunks = list(iter(lambda: bz2f.read(10), b''))
        self.assertEqual(b''.join(chunks), self.TEXT)

    def testReadChunk10MultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            # Pull 10 bytes at a time until read() signals EOF with b''.
            chunks = list(iter(lambda: bz2f.read(10), b''))
        self.assertEqual(b''.join(chunks), self.TEXT * 5)

    def testRead100(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            head = bz2f.read(100)
        self.assertEqual(head, self.TEXT[:100])

    def testPeek(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            pdata = bz2f.peek()
            # peek() must return something, and it must be a prefix of
            # the text without consuming it.
            self.assertNotEqual(len(pdata), 0)
            self.assertTrue(self.TEXT.startswith(pdata))
            self.assertEqual(bz2f.read(), self.TEXT)

    def testReadInto(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # First fill a small buffer completely ...
            first = 128
            buf = bytearray(first)
            self.assertEqual(bz2f.readinto(buf), first)
            self.assertEqual(buf, self.TEXT[:first])
            # ... then read the remainder into an oversized buffer.
            rest = len(self.TEXT) - first
            buf = bytearray(len(self.TEXT))
            self.assertEqual(bz2f.readinto(buf), rest)
            self.assertEqual(buf[:rest], self.TEXT[-rest:])

    def testReadLine(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            self.assertRaises(TypeError, bz2f.readline, None)
            for expected in self.TEXT_LINES:
                self.assertEqual(bz2f.readline(), expected)

    def testReadLineMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            self.assertRaises(TypeError, bz2f.readline, None)
            for expected in self.TEXT_LINES * 5:
                self.assertEqual(bz2f.readline(), expected)

    def testReadLines(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            self.assertRaises(TypeError, bz2f.readlines, None)
            self.assertEqual(bz2f.readlines(), self.TEXT_LINES)

    def testReadLinesMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            self.assertRaises(TypeError, bz2f.readlines, None)
            self.assertEqual(bz2f.readlines(), self.TEXT_LINES * 5)

    def testIterator(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            lines = list(iter(bz2f))
        self.assertEqual(lines, self.TEXT_LINES)

    def testIteratorMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            lines = list(iter(bz2f))
        self.assertEqual(lines, self.TEXT_LINES * 5)

    def testClosedIteratorDeadlock(self):
        # Issue #3309: Iteration on a closed BZ2File should release the lock.
        self.createTempFile()
        bz2f = BZ2File(self.filename)
        bz2f.close()
        with self.assertRaises(ValueError):
            next(bz2f)
        # This call will deadlock if the above call failed to release the lock.
        with self.assertRaises(ValueError):
            bz2f.readlines()

    def testWrite(self):
        with BZ2File(self.filename, "w") as bz2f:
            # write() requires an argument.
            self.assertRaises(TypeError, bz2f.write)
            bz2f.write(self.TEXT)
        with open(self.filename, 'rb') as f:
            compressed = f.read()
        self.assertEqual(ext_decompress(compressed), self.TEXT)

    def testWriteChunks10(self):
        with BZ2File(self.filename, "w") as bz2f:
            # Feed the text to the compressor 10 bytes at a time.
            for start in range(0, len(self.TEXT), 10):
                bz2f.write(self.TEXT[start:start + 10])
        with open(self.filename, 'rb') as f:
            compressed = f.read()
        self.assertEqual(ext_decompress(compressed), self.TEXT)

    def testWriteNonDefaultCompressLevel(self):
        # The file output must match one-shot compression at the same level.
        expected = bz2.compress(self.TEXT, compresslevel=5)
        with BZ2File(self.filename, "w", compresslevel=5) as bz2f:
            bz2f.write(self.TEXT)
        with open(self.filename, "rb") as f:
            written = f.read()
        self.assertEqual(written, expected)

    def testWriteLines(self):
        with BZ2File(self.filename, "w") as bz2f:
            self.assertRaises(TypeError, bz2f.writelines)
            bz2f.writelines(self.TEXT_LINES)
        # Issue #1535500: Calling writelines() on a closed BZ2File
        # should raise an exception.
        self.assertRaises(ValueError, bz2f.writelines, ["a"])
        with open(self.filename, 'rb') as f:
            compressed = f.read()
        self.assertEqual(ext_decompress(compressed), self.TEXT)

    def testWriteMethodsOnReadOnlyFile(self):
        # Create a valid archive first ...
        with BZ2File(self.filename, "w") as bz2f:
            bz2f.write(b"abc")

        # ... then check that a reader refuses both write entry points.
        with BZ2File(self.filename, "r") as bz2f:
            with self.assertRaises(OSError):
                bz2f.write(b"a")
            with self.assertRaises(OSError):
                bz2f.writelines([b"a"])

    def testAppend(self):
        # Write one stream, then append a second one in a separate session.
        for mode in ("w", "a"):
            with BZ2File(self.filename, mode) as bz2f:
                self.assertRaises(TypeError, bz2f.write)
                bz2f.write(self.TEXT)
        with open(self.filename, 'rb') as f:
            compressed = f.read()
        self.assertEqual(ext_decompress(compressed), self.TEXT * 2)

    def testSeekForward(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # seek() requires an offset.
            self.assertRaises(TypeError, bz2f.seek)
            bz2f.seek(150)
            self.assertEqual(bz2f.read(), self.TEXT[150:])

    def testSeekForwardAcrossStreams(self):
        self.createTempFile(streams=2)
        with BZ2File(self.filename) as bz2f:
            self.assertRaises(TypeError, bz2f.seek)
            # Land 150 bytes into the second stream.
            bz2f.seek(len(self.TEXT) + 150)
            self.assertEqual(bz2f.read(), self.TEXT[150:])

    def testSeekBackwards(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            bz2f.read(500)
            # Relative seek: 500 - 150 leaves us at offset 350.
            bz2f.seek(-150, 1)
            self.assertEqual(bz2f.read(), self.TEXT[350:])

    def testSeekBackwardsAcrossStreams(self):
        self.createTempFile(streams=2)
        with BZ2File(self.filename) as bz2f:
            # Advance 100 bytes into the second stream; read() may return
            # short, so keep going until the full amount was consumed.
            remaining = len(self.TEXT) + 100
            while remaining > 0:
                remaining -= len(bz2f.read(remaining))
            # Stepping back 150 bytes lands 50 bytes before the end of the
            # first stream.
            bz2f.seek(-150, 1)
            self.assertEqual(bz2f.read(), self.TEXT[-50:] + self.TEXT)

    def testSeekBackwardsFromEnd(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            bz2f.seek(-150, 2)
            self.assertEqual(bz2f.read(), self.TEXT[len(self.TEXT) - 150:])

    def testSeekBackwardsFromEndAcrossStreams(self):
        self.createTempFile(streams=2)
        with BZ2File(self.filename) as bz2f:
            bz2f.seek(-1000, 2)
            self.assertEqual(bz2f.read(), (self.TEXT * 2)[-1000:])

    def testSeekPostEnd(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # Seeking beyond EOF clamps to the end of the data.
            bz2f.seek(150000)
            self.assertEqual(bz2f.tell(), len(self.TEXT))
            self.assertEqual(bz2f.read(), b"")

    def testSeekPostEndMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            bz2f.seek(150000)
            self.assertEqual(bz2f.tell(), len(self.TEXT) * 5)
            self.assertEqual(bz2f.read(), b"")

    def testSeekPostEndTwice(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # A second out-of-range seek must leave the position unchanged.
            for _ in range(2):
                bz2f.seek(150000)
            self.assertEqual(bz2f.tell(), len(self.TEXT))
            self.assertEqual(bz2f.read(), b"")

    def testSeekPostEndTwiceMultiStream(self):
        self.createTempFile(streams=5)
        with BZ2File(self.filename) as bz2f:
            for _ in range(2):
                bz2f.seek(150000)
            self.assertEqual(bz2f.tell(), len(self.TEXT) * 5)
            self.assertEqual(bz2f.read(), b"")

    def testSeekPreStart(self):
        self.createTempFile()
        with BZ2File(self.filename) as bz2f:
            # A negative absolute offset clamps to the start of the data.
            bz2f.seek(-150)
            self.assertEqual(bz2f.tell(), 0)
            self.assertEqual(bz2f.read(), self.TEXT)

    def testSeekPreStartMultiStream(self):
        self.createTempFile(streams=2)
        with BZ2File(self.filename) as bz2f:
            bz2f.seek(-150)
            self.assertEqual(bz2f.tell(), 0)
            self.assertEqual(bz2f.read(), self.TEXT * 2)

    def testFileno(self):
        self.createTempFile()
        with open(self.filename, 'rb') as rawf:
            with BZ2File(rawf) as bz2f:
                # The wrapper reports the descriptor of the wrapped file.
                self.assertEqual(bz2f.fileno(), rawf.fileno())
        # Once closed, fileno() is no longer available.
        self.assertRaises(ValueError, bz2f.fileno)

    def testSeekable(self):
        # A reader over a seekable source is seekable, before and after
        # the data has been consumed.
        with BZ2File(BytesIO(self.DATA)) as bz2f:
            self.assertTrue(bz2f.seekable())
            bz2f.read()
            self.assertTrue(bz2f.seekable())
        self.assertRaises(ValueError, bz2f.seekable)

        # A writer is never seekable.
        with BZ2File(BytesIO(), "w") as bz2f:
            self.assertFalse(bz2f.seekable())
        self.assertRaises(ValueError, bz2f.seekable)

        # A reader over a non-seekable source is not seekable either.
        src = BytesIO(self.DATA)
        src.seekable = lambda: False
        with BZ2File(src) as bz2f:
            self.assertFalse(bz2f.seekable())
        self.assertRaises(ValueError, bz2f.seekable)

    def testReadable(self):
        with BZ2File(BytesIO(self.DATA)) as bz2f:
            self.assertTrue(bz2f.readable())
            bz2f.read()
            self.assertTrue(bz2f.readable())
        # Querying a closed file is an error.
        self.assertRaises(ValueError, bz2f.readable)

        with BZ2File(BytesIO(), "w") as bz2f:
            self.assertFalse(bz2f.readable())
        self.assertRaises(ValueError, bz2f.readable)

    def testWritable(self):
        with BZ2File(BytesIO(self.DATA)) as bz2f:
            self.assertFalse(bz2f.writable())
            bz2f.read()
            self.assertFalse(bz2f.writable())
        # Querying a closed file is an error.
        self.assertRaises(ValueError, bz2f.writable)

        with BZ2File(BytesIO(), "w") as bz2f:
            self.assertTrue(bz2f.writable())
        self.assertRaises(ValueError, bz2f.writable)

    def testOpenDel(self):
        # Repeatedly open the file and drop the only reference without
        # calling close(); the objects are deliberately not closed here.
        self.createTempFile()
        for _ in range(10000):
            o = BZ2File(self.filename)
            del o

    def testOpenNonexistent(self):
        with self.assertRaises(OSError):
            BZ2File("/non/existent")

    def testReadlinesNoNewline(self):
        # Issue #1191043: readlines() fails on a file containing no newline.
        data = b'BZh91AY&SY\xd9b\x89]\x00\x00\x00\x03\x80\x04\x00\x02\x00\x0c\x00 \x00!\x9ah3M\x13<]\xc9\x14\xe1BCe\x8a%t'
        with open(self.filename, "wb") as f:
            f.write(data)
        with BZ2File(self.filename) as bz2f:
            lines = bz2f.readlines()
        self.assertEqual(lines, [b'Test'])
        # Same check with the result materialised through list().
        with BZ2File(self.filename) as bz2f:
            xlines = list(bz2f.readlines())
        self.assertEqual(xlines, [b'Test'])

    def testContextProtocol(self):
        # Normal use as a context manager.
        with BZ2File(self.filename, "wb") as f:
            f.write(b"xxx")
        # Entering an already-closed file must fail.
        f = BZ2File(self.filename, "rb")
        f.close()
        with self.assertRaises(
                ValueError,
                msg="__enter__ on a closed file didn't raise an exception"):
            with f:
                pass
        # An exception raised in the body propagates out of the with block
        # and the file is closed on the way out.
        with self.assertRaises(ZeroDivisionError,
                               msg="1/0 didn't raise an exception"):
            with BZ2File(self.filename, "wb") as f:
                1/0
        self.assertTrue(f.closed)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def testThreading(self):
        # Issue #7205: Using a BZ2File from several threads shouldn't deadlock.
        data = b"1" * 2**20
        nthreads = 10
        with BZ2File(self.filename, 'wb') as f:
            def comp():
                # Each worker writes the 1 MiB payload five times.
                for _ in range(5):
                    f.write(data)
            threads = [threading.Thread(target=comp) for _ in range(nthreads)]
            # start_threads() starts the workers on entry and joins them
            # on exit, so the body has nothing to do.
            with support.start_threads(threads):
                pass

    def testWithoutThreading(self):
        # Re-import bz2 with the threading module blocked and check that a
        # basic write/read round trip still works.
        module = support.import_fresh_module("bz2", blocked=("threading",))
        with module.BZ2File(self.filename, "wb") as f:
            f.write(b"abc")
        with module.BZ2File(self.filename, "rb") as f:
            self.assertEqual(f.read(), b"abc")

    def testMixedIterationAndReads(self):
        self.createTempFile()
        linelen = len(self.TEXT_LINES[0])
        halflen = linelen // 2
        # read() followed by next() and read().
        with BZ2File(self.filename) as bz2f:
            bz2f.read(halflen)
            self.assertEqual(next(bz2f), self.TEXT_LINES[0][halflen:])
            self.assertEqual(bz2f.read(), self.TEXT[linelen:])
        # readline() interleaved with next().
        with BZ2File(self.filename) as bz2f:
            bz2f.readline()
            self.assertEqual(next(bz2f), self.TEXT_LINES[1])
            self.assertEqual(bz2f.readline(), self.TEXT_LINES[2])
        # After readlines() has drained the file nothing is left.
        with BZ2File(self.filename) as bz2f:
            bz2f.readlines()
            self.assertRaises(StopIteration, next, bz2f)
            self.assertEqual(bz2f.readlines(), [])

    def testMultiStreamOrdering(self):
        # Test the ordering of streams when reading a multi-stream archive.
        data1 = b"foo" * 1000
        data2 = b"bar" * 1000
        for mode, payload in (("w", data1), ("a", data2)):
            with BZ2File(self.filename, mode) as bz2f:
                bz2f.write(payload)
        with BZ2File(self.filename) as bz2f:
            self.assertEqual(bz2f.read(), data1 + data2)

    def testOpenBytesFilename(self):
        str_filename = self.filename
        try:
            bytes_filename = str_filename.encode("ascii")
        except UnicodeEncodeError:
            self.skipTest("Temporary file name needs to be ASCII")
        with BZ2File(bytes_filename, "wb") as f:
            f.write(self.DATA)
        with BZ2File(bytes_filename, "rb") as f:
            self.assertEqual(f.read(), self.DATA)
        # Sanity check that we are actually operating on the right file.
        with BZ2File(str_filename, "rb") as f:
            self.assertEqual(f.read(), self.DATA)

    def testOpenPathLikeFilename(self):
        # A pathlib.Path is accepted wherever a file name is.
        filename = pathlib.Path(self.filename)
        with BZ2File(filename, "wb") as f:
            f.write(self.DATA)
        with BZ2File(filename, "rb") as f:
            self.assertEqual(f.read(), self.DATA)

    def testDecompressLimited(self):
        """Decompressed data buffering should be limited"""
        bomb = bz2.compress(b'\0' * int(2e6), compresslevel=9)
        # The whole bomb must fit in a single raw read for the check below
        # to be meaningful.
        self.assertLess(len(bomb), _compression.BUFFER_SIZE)

        # Close the reader when done instead of leaking it.
        with BZ2File(BytesIO(bomb)) as decomp:
            self.assertEqual(decomp.read(1), b'\0')
            max_decomp = 1 + DEFAULT_BUFFER_SIZE
            self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
                "Excessive amount of data was decompressed")


    # Tests for a BZ2File wrapping another file object:

    def testReadBytesIO(self):
        with BytesIO(self.DATA) as bio:
            with BZ2File(bio) as bz2f:
                with self.assertRaises(TypeError):
                    bz2f.read(float())
                self.assertEqual(bz2f.read(), self.TEXT)
            # Closing the wrapper must not close the wrapped object.
            self.assertFalse(bio.closed)

    def testPeekBytesIO(self):
        with BytesIO(self.DATA) as bio:
            with BZ2File(bio) as bz2f:
                pdata = bz2f.peek()
                self.assertNotEqual(len(pdata), 0)
                self.assertTrue(self.TEXT.startswith(pdata))
                self.assertEqual(bz2f.read(), self.TEXT)

    def testWriteBytesIO(self):
        with BytesIO() as bio:
            with BZ2File(bio, "w") as bz2f:
                self.assertRaises(TypeError, bz2f.write)
                bz2f.write(self.TEXT)
            compressed = bio.getvalue()
            self.assertEqual(ext_decompress(compressed), self.TEXT)
            # Closing the wrapper must not close the wrapped object.
            self.assertFalse(bio.closed)

    def testSeekForwardBytesIO(self):
        with BytesIO(self.DATA) as bio:
            with BZ2File(bio) as bz2f:
                self.assertRaises(TypeError, bz2f.seek)
                bz2f.seek(150)
                self.assertEqual(bz2f.read(), self.TEXT[150:])

    def testSeekBackwardsBytesIO(self):
        with BytesIO(self.DATA) as bio:
            with BZ2File(bio) as bz2f:
                bz2f.read(500)
                bz2f.seek(-150, 1)
                self.assertEqual(bz2f.read(), self.TEXT[350:])

    def test_read_truncated(self):
        # Drop the eos_magic field (6 bytes) and CRC (4 bytes).
        truncated = self.DATA[:-10]
        # Reading to EOF hits the missing trailer.
        with BZ2File(BytesIO(truncated)) as f:
            self.assertRaises(EOFError, f.read)
        # All of the payload is still retrievable; only the next read fails.
        with BZ2File(BytesIO(truncated)) as f:
            self.assertEqual(f.read(len(self.TEXT)), self.TEXT)
            self.assertRaises(EOFError, f.read, 1)
        # Incomplete 4-byte file header, and block header of at least 146 bits.
        for length in range(22):
            with self.subTest(length=length):
                with BZ2File(BytesIO(truncated[:length])) as f:
                    self.assertRaises(EOFError, f.read, 1)
627
628
class BZ2CompressorTest(BaseTest):
    # Tests for the incremental BZ2Compressor object.

    def testCompress(self):
        bz2c = BZ2Compressor()
        # compress() requires an argument.
        self.assertRaises(TypeError, bz2c.compress)
        data = bz2c.compress(self.TEXT) + bz2c.flush()
        self.assertEqual(ext_decompress(data), self.TEXT)

    def testCompressEmptyString(self):
        bz2c = BZ2Compressor()
        data = bz2c.compress(b'') + bz2c.flush()
        self.assertEqual(data, self.EMPTY_DATA)

    def testCompressChunks10(self):
        bz2c = BZ2Compressor()
        # Feed the text 10 bytes at a time and collect whatever the
        # compressor emits, then append the final flush() output.
        parts = [bz2c.compress(self.TEXT[start:start + 10])
                 for start in range(0, len(self.TEXT), 10)]
        parts.append(bz2c.flush())
        self.assertEqual(ext_decompress(b''.join(parts)), self.TEXT)

    @bigmemtest(size=_4G + 100, memuse=2)
    def testCompress4G(self, size):
        # "Test BZ2Compressor.compress()/flush() with >4GiB input"
        bz2c = BZ2Compressor()
        data = b"x" * size
        try:
            compressed = bz2c.compress(data)
            compressed += bz2c.flush()
        finally:
            data = None  # Release memory
        data = bz2.decompress(compressed)
        try:
            self.assertEqual(len(data), size)
            # Stripping every b"x" must leave nothing behind.
            self.assertEqual(len(data.strip(b"x")), 0)
        finally:
            data = None  # Release memory

    def testPickle(self):
        # Compressor objects cannot be pickled under any protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(protocol=proto):
                with self.assertRaises(TypeError):
                    pickle.dumps(BZ2Compressor(), proto)
677
678
class BZ2DecompressorTest(BaseTest):
    # Tests for the incremental BZ2Decompressor object.

    def test_Constructor(self):
        # The constructor takes no arguments.
        self.assertRaises(TypeError, BZ2Decompressor, 42)

    def testDecompress(self):
        bz2d = BZ2Decompressor()
        # decompress() requires an argument.
        self.assertRaises(TypeError, bz2d.decompress)
        self.assertEqual(bz2d.decompress(self.DATA), self.TEXT)

    def testDecompressChunks10(self):
        bz2d = BZ2Decompressor()
        # Feed the stream 10 bytes at a time and collect the output.
        pieces = [bz2d.decompress(self.DATA[start:start + 10])
                  for start in range(0, len(self.DATA), 10)]
        self.assertEqual(b''.join(pieces), self.TEXT)

    def testDecompressUnusedData(self):
        bz2d = BZ2Decompressor()
        unused_data = b"this is unused data"
        text = bz2d.decompress(self.DATA + unused_data)
        self.assertEqual(text, self.TEXT)
        # Bytes following the end of the stream are kept in unused_data.
        self.assertEqual(bz2d.unused_data, unused_data)

    def testEOFError(self):
        bz2d = BZ2Decompressor()
        # Consume one complete stream ...
        self.assertEqual(bz2d.decompress(self.DATA), self.TEXT)
        # ... after which any further call fails, even with empty input.
        self.assertRaises(EOFError, bz2d.decompress, b"anything")
        self.assertRaises(EOFError, bz2d.decompress, b"")

    @bigmemtest(size=_4G + 100, memuse=3.3)
    def testDecompress4G(self, size):
        # "Test BZ2Decompressor.decompress() with >4GiB input"
        blocksize = 10 * 1024 * 1024
        block = random.getrandbits(blocksize * 8).to_bytes(blocksize, 'little')
        try:
            data = block * (size // blocksize + 1)
            compressed = bz2.compress(data)
            bz2d = BZ2Decompressor()
            decompressed = bz2d.decompress(compressed)
            # Compare with == rather than assertEqual so that a failure
            # does not try to render a diff of multi-gigabyte operands.
            self.assertTrue(decompressed == data)
        finally:
            # Release memory
            data = None
            compressed = None
            decompressed = None

    def testPickle(self):
        # Decompressor objects cannot be pickled under any protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(protocol=proto):
                with self.assertRaises(TypeError):
                    pickle.dumps(BZ2Decompressor(), proto)

    def testDecompressorChunksMaxsize(self):
        bzd = BZ2Decompressor()
        max_length = 100
        out = []

        # Feed some input
        split = len(self.BIG_DATA) - 64
        chunk = bzd.decompress(self.BIG_DATA[:split], max_length=max_length)
        out.append(chunk)
        self.assertFalse(bzd.needs_input)
        self.assertEqual(len(chunk), max_length)

        # Retrieve more data without providing more input
        chunk = bzd.decompress(b'', max_length=max_length)
        out.append(chunk)
        self.assertFalse(bzd.needs_input)
        self.assertEqual(len(chunk), max_length)

        # Retrieve more data while providing more input
        chunk = bzd.decompress(self.BIG_DATA[split:], max_length=max_length)
        out.append(chunk)
        self.assertLessEqual(len(chunk), max_length)

        # Retrieve remaining uncompressed data
        while not bzd.eof:
            chunk = bzd.decompress(b'', max_length=max_length)
            out.append(chunk)
            self.assertLessEqual(len(chunk), max_length)

        self.assertEqual(b"".join(out), self.BIG_TEXT)
        self.assertEqual(bzd.unused_data, b"")

    def test_decompressor_inputbuf_1(self):
        # Test reusing input buffer after moving existing
        # contents to beginning
        bzd = BZ2Decompressor()

        # Create input buffer and fill it
        self.assertEqual(bzd.decompress(self.DATA[:100],
                                        max_length=0), b'')

        out = [
            # Retrieve some results, freeing capacity at beginning
            # of input buffer
            bzd.decompress(b'', 2),
            # Add more data that fits into input buffer after
            # moving existing data to beginning
            bzd.decompress(self.DATA[100:105], 15),
            # Decompress rest of data
            bzd.decompress(self.DATA[105:]),
        ]
        self.assertEqual(b''.join(out), self.TEXT)

    def test_decompressor_inputbuf_2(self):
        # Test reusing input buffer by appending data at the
        # end right away
        bzd = BZ2Decompressor()

        # Create input buffer and empty it
        self.assertEqual(bzd.decompress(self.DATA[:200],
                                        max_length=0), b'')
        out = [
            bzd.decompress(b''),
            # Fill buffer with new data
            bzd.decompress(self.DATA[200:280], 2),
            # Append some more data, not enough to require resize
            bzd.decompress(self.DATA[280:300], 2),
            # Decompress rest of data
            bzd.decompress(self.DATA[300:]),
        ]
        self.assertEqual(b''.join(out), self.TEXT)

    def test_decompressor_inputbuf_3(self):
        # Test reusing input buffer after extending it
        bzd = BZ2Decompressor()

        out = [
            # Create almost full input buffer
            bzd.decompress(self.DATA[:200], 5),
            # Add even more data to it, requiring resize
            bzd.decompress(self.DATA[200:300], 5),
            # Decompress rest of data
            bzd.decompress(self.DATA[300:]),
        ]
        self.assertEqual(b''.join(out), self.TEXT)

    def test_failure(self):
        bzd = BZ2Decompressor()
        garbage = self.BAD_DATA * 30
        self.assertRaises(Exception, bzd.decompress, garbage)
        # Previously, a second call could crash due to internal inconsistency
        self.assertRaises(Exception, bzd.decompress, garbage)
830
class CompressDecompressTest(BaseTest):
    "Test the one-shot compress() and decompress() functions."

    def testCompress(self):
        # The compressed output is checked with ext_decompress(), which
        # uses the bunzip2 command when it is available.
        compressed = bz2.compress(self.TEXT)
        self.assertEqual(ext_decompress(compressed), self.TEXT)

    def testCompressEmptyString(self):
        self.assertEqual(bz2.compress(b''), self.EMPTY_DATA)

    def testDecompress(self):
        self.assertEqual(bz2.decompress(self.DATA), self.TEXT)

    def testDecompressEmpty(self):
        # No input at all gives no output
        self.assertEqual(bz2.decompress(b""), b"")

    def testDecompressToEmptyString(self):
        self.assertEqual(bz2.decompress(self.EMPTY_DATA), b'')

    def testDecompressIncomplete(self):
        # A stream with its last 10 bytes cut off is rejected
        truncated = self.DATA[:-10]
        with self.assertRaises(ValueError):
            bz2.decompress(truncated)

    def testDecompressBadData(self):
        with self.assertRaises(OSError):
            bz2.decompress(self.BAD_DATA)

    def testDecompressMultiStream(self):
        # Five concatenated streams decompress to five copies of TEXT
        streams = self.DATA * 5
        self.assertEqual(bz2.decompress(streams), self.TEXT * 5)

    def testDecompressTrailingJunk(self):
        # Invalid data following a complete stream is ignored
        payload = self.DATA + self.BAD_DATA
        self.assertEqual(bz2.decompress(payload), self.TEXT)

    def testDecompressMultiStreamTrailingJunk(self):
        payload = self.DATA * 5 + self.BAD_DATA
        self.assertEqual(bz2.decompress(payload), self.TEXT * 5)
869
870
class OpenTest(BaseTest):
    "Tests for the module-level open() function."

    def open(self, *args, **kwargs):
        # Single place through which every test reaches bz2.open()
        return bz2.open(*args, **kwargs)

    def _raw_payload(self):
        # Read self.filename with the builtin open() and return its
        # contents as decompressed by ext_decompress().
        with open(self.filename, "rb") as raw:
            return ext_decompress(raw.read())

    def _check_bytes_round_trip(self, create_modes, read_mode, append_mode):
        # For each creating mode: write TEXT, verify the file on disk,
        # read it back, append TEXT once more and verify the file again.
        for mode in create_modes:
            if mode.startswith("x"):
                # Exclusive creation needs the file to be absent
                unlink(self.filename)
            with self.open(self.filename, mode) as dst:
                dst.write(self.TEXT)
            self.assertEqual(self._raw_payload(), self.TEXT)
            with self.open(self.filename, read_mode) as src:
                self.assertEqual(src.read(), self.TEXT)
            with self.open(self.filename, append_mode) as dst:
                dst.write(self.TEXT)
            self.assertEqual(self._raw_payload(), self.TEXT * 2)

    def test_binary_modes(self):
        self._check_bytes_round_trip(("wb", "xb"), "rb", "ab")

    def test_implicit_binary_modes(self):
        # Test implicit binary modes (no "b" or "t" in mode string).
        self._check_bytes_round_trip(("w", "x"), "r", "a")

    def test_text_modes(self):
        # Text reaches the file with the platform line separator;
        # reading it back in text mode yields the original text.
        expected = self.TEXT.decode("ascii")
        expected_on_disk = expected.replace("\n", os.linesep)
        for mode in ("wt", "xt"):
            if mode == "xt":
                unlink(self.filename)
            with self.open(self.filename, mode) as dst:
                dst.write(expected)
            on_disk = self._raw_payload().decode("ascii")
            self.assertEqual(on_disk, expected_on_disk)
            with self.open(self.filename, "rt") as src:
                self.assertEqual(src.read(), expected)
            with self.open(self.filename, "at") as dst:
                dst.write(expected)
            on_disk = self._raw_payload().decode("ascii")
            self.assertEqual(on_disk, expected_on_disk * 2)

    def test_x_mode(self):
        for mode in ("x", "xb", "xt"):
            unlink(self.filename)
            # The first exclusive open creates the file ...
            with self.open(self.filename, mode):
                pass
            # ... so a second one has to be refused.
            with self.assertRaises(FileExistsError):
                with self.open(self.filename, mode):
                    pass

    def test_fileobj(self):
        # An existing file object is accepted in place of a filename
        for mode in ("r", "rb"):
            with self.open(BytesIO(self.DATA), mode) as src:
                self.assertEqual(src.read(), self.TEXT)
        expected = self.TEXT.decode("ascii")
        with self.open(BytesIO(self.DATA), "rt") as src:
            self.assertEqual(src.read(), expected)

    def test_bad_params(self):
        # Test invalid parameter combinations.
        rejected = (
            ("wbt", {}),
            ("xbt", {}),
            ("rb", {"encoding": "utf-8"}),
            ("rb", {"errors": "ignore"}),
            ("rb", {"newline": "\n"}),
        )
        for mode, text_options in rejected:
            with self.assertRaises(ValueError):
                self.open(self.filename, mode, **text_options)

    def test_encoding(self):
        # Test non-default encoding.
        expected = self.TEXT.decode("ascii")
        expected_on_disk = expected.replace("\n", os.linesep)
        with self.open(self.filename, "wt", encoding="utf-16-le") as dst:
            dst.write(expected)
        on_disk = self._raw_payload().decode("utf-16-le")
        self.assertEqual(on_disk, expected_on_disk)
        with self.open(self.filename, "rt", encoding="utf-16-le") as src:
            self.assertEqual(src.read(), expected)

    def test_encoding_error_handler(self):
        # Test with non-default encoding error handler.
        with self.open(self.filename, "wb") as dst:
            dst.write(b"foo\xffbar")
        with self.open(self.filename, "rt",
                       encoding="ascii", errors="ignore") as src:
            self.assertEqual(src.read(), "foobar")

    def test_newline(self):
        # Test with explicit newline (universal newline mode disabled).
        expected = self.TEXT.decode("ascii")
        with self.open(self.filename, "wt", newline="\n") as dst:
            dst.write(expected)
        # With "\r" as the line terminator, the "\n"-separated text
        # comes back as one single line.
        with self.open(self.filename, "rt", newline="\r") as src:
            self.assertEqual(src.readlines(), [expected])
989
990
def test_main():
    # Run every test case class of this module.  ext_decompress() may
    # start bunzip2 subprocesses, hence the reap_children() call at the
    # end.
    test_classes = (
        BZ2FileTest,
        BZ2CompressorTest,
        BZ2DecompressorTest,
        CompressDecompressTest,
        OpenTest,
    )
    support.run_unittest(*test_classes)
    support.reap_children()
1000
# Run the whole test suite when this file is executed as a script.
if __name__ == '__main__':
    test_main()
1003