"""Unit tests for the heapq module."""
2
3import sys
4import random
5
6from test import test_support
7from unittest import TestCase, skipUnless
8
# heapq imported with the _heapq accelerator blocked, so every function comes
# from the pure-Python heapq module.
py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
# heapq imported together with a fresh _heapq.  Presumably falsy when _heapq
# is unavailable -- the skipUnless(c_heapq, ...) decorators rely on that.
# NOTE(review): confirm against test_support.import_fresh_module.
c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])

# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_nsmallest when
# _heapq is imported, so check them there.
# Names whose __module__ TestModules checks for each implementation.
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
              'heapreplace', '_nlargest', '_nsmallest']
16
class TestModules(TestCase):
    """Verify which implementation backs each function in func_names."""

    def test_py_functions(self):
        # With _heapq blocked, everything must be defined in heapq itself.
        for name in func_names:
            func = getattr(py_heapq, name)
            self.assertEqual(func.__module__, 'heapq')

    @skipUnless(c_heapq, 'requires _heapq')
    def test_c_functions(self):
        # With _heapq available, the accelerated versions must be in use.
        for name in func_names:
            func = getattr(c_heapq, name)
            self.assertEqual(func.__module__, '_heapq')
26
27
class TestHeap(TestCase):
    """Behavioural tests shared by every heapq implementation.

    ``module`` is None here; concrete subclasses bind it to the heapq module
    under test, and only those subclasses are handed to run_unittest.
    """
    module = None

    def test_push_pop(self):
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
        heap = []
        data = []
        self.check_invariant(heap)
        for _ in range(256):
            item = random.random()
            data.append(item)
            self.module.heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = self.module.heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        self.assertEqual(sorted(data), results)
        # 2) Check that the invariant holds for a sorted array
        self.check_invariant(results)

        self.assertRaises(TypeError, self.module.heappush, [])
        # A non-list heap must be rejected.  An implementation that calls
        # list methods on the argument may raise AttributeError instead of
        # TypeError, so accept either.  The two checks are independent:
        # an AttributeError from heappush must not skip the heappop check.
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappush, None, None)
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappop, None)

    def check_invariant(self, heap):
        """Assert the min-heap invariant: no element is smaller than its parent."""
        for pos, item in enumerate(heap):
            if pos:  # pos 0 has no parent
                parentpos = (pos - 1) >> 1
                # assertLessEqual reports both values on failure.
                self.assertLessEqual(heap[parentpos], item)

    def test_heapify(self):
        for size in range(30):
            heap = [random.random() for dummy in range(size)]
            self.module.heapify(heap)
            self.check_invariant(heap)

        self.assertRaises(TypeError, self.module.heapify, None)

    def test_naive_nbest(self):
        # Keep a heap of at most 10 items, dropping the smallest each time
        # it grows past 10; what remains are the 10 largest values.
        data = [random.randrange(2000) for i in range(1000)]
        heap = []
        for item in data:
            self.module.heappush(heap, item)
            if len(heap) > 10:
                self.module.heappop(heap)
        heap.sort()
        self.assertEqual(heap, sorted(data)[-10:])

    def heapiter(self, heap):
        """Yield the heap's elements smallest-first, consuming *heap*."""
        while heap:
            yield self.module.heappop(heap)

    def test_nbest(self):
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
        # enough <wink>) than sorting all of data.  However, if we had a max
        # heap instead of a min heap, it could go faster still via
        # heapify'ing all of data (linear time), then doing 10 heappops
        # (10 log-time steps).
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            if item > heap[0]:  # this gets rarer the longer we run
                self.module.heapreplace(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])

        self.assertRaises(TypeError, self.module.heapreplace, None)
        self.assertRaises(TypeError, self.module.heapreplace, None, None)
        self.assertRaises(IndexError, self.module.heapreplace, [], None)

    def test_nbest_with_pushpop(self):
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            self.module.heappushpop(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
        self.assertEqual(self.module.heappushpop([], 'x'), 'x')

    def test_heappushpop(self):
        h = []
        x = self.module.heappushpop(h, 10)
        self.assertEqual((h, x), ([], 10))

        # An equal item is returned immediately; the heap keeps its original.
        h = [10]
        x = self.module.heappushpop(h, 10.0)
        self.assertEqual((h, x), ([10], 10.0))
        self.assertEqual(type(h[0]), int)
        self.assertEqual(type(x), float)

        h = [10]
        x = self.module.heappushpop(h, 9)
        self.assertEqual((h, x), ([10], 9))

        h = [10]
        x = self.module.heappushpop(h, 11)
        self.assertEqual((h, x), ([11], 10))

    def test_heapsort(self):
        # Exercise everything with repeated heapsort checks
        for trial in xrange(100):
            size = random.randrange(50)
            data = [random.randrange(25) for i in range(size)]
            if trial & 1:     # Half of the time, use heapify
                heap = data[:]
                self.module.heapify(heap)
            else:             # The rest of the time, use heappush
                heap = []
                for item in data:
                    self.module.heappush(heap, item)
            heap_sorted = [self.module.heappop(heap) for i in range(size)]
            self.assertEqual(heap_sorted, sorted(data))

    def test_merge(self):
        # Imported locally so this test does not depend on a module-level
        # import that appears further down the file.
        from itertools import chain
        inputs = []
        for _ in xrange(random.randrange(5)):
            row = sorted(random.randrange(1000)
                         for _ in range(random.randrange(10)))
            inputs.append(row)
        self.assertEqual(sorted(chain(*inputs)),
                         list(self.module.merge(*inputs)))
        self.assertEqual(list(self.module.merge()), [])

    def test_merge_does_not_suppress_index_error(self):
        # Issue 19018: Heapq.merge suppresses IndexError from user generator
        def iterable():
            s = list(range(10))
            for i in range(20):
                yield s[i]       # IndexError once i >= 10 (s has 10 items)
        with self.assertRaises(IndexError):
            list(self.module.merge(iterable(), iterable()))

    def test_merge_stability(self):
        # Equal values must come out in input-stream order.
        class Int(int):
            pass
        inputs = [[], [], [], []]
        for i in range(20000):
            stream = random.randrange(4)
            x = random.randrange(500)
            obj = Int(x)
            obj.pair = (x, stream)
            inputs[stream].append(obj)
        for stream in inputs:
            stream.sort()
        result = [i.pair for i in self.module.merge(*inputs)]
        self.assertEqual(result, sorted(result))

    def test_nsmallest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x: x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(self.module.nsmallest(n, data),
                                 sorted(data)[:n])
                self.assertEqual(self.module.nsmallest(n, data, key=f),
                                 sorted(data, key=f)[:n])

    def test_nlargest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x: x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(self.module.nlargest(n, data),
                                 sorted(data, reverse=True)[:n])
                self.assertEqual(self.module.nlargest(n, data, key=f),
                                 sorted(data, key=f, reverse=True)[:n])

    def test_comparison_operator(self):
        # Issue 3051: Make sure heapq works with both __lt__ and __le__
        def hsort(data, comp):
            # heapify needs a real list; build one explicitly rather than
            # relying on map() returning a list.
            data = [comp(x) for x in data]
            self.module.heapify(data)
            return [self.module.heappop(data).x for i in range(len(data))]
        class LT:
            def __init__(self, x):
                self.x = x
            def __lt__(self, other):
                return self.x > other.x
        class LE:
            def __init__(self, x):
                self.x = x
            def __le__(self, other):
                return self.x >= other.x
        data = [random.random() for i in range(100)]
        # Both wrappers invert the ordering, so popping yields descending x.
        target = sorted(data, reverse=True)
        self.assertEqual(hsort(data, LT), target)
        self.assertEqual(hsort(data, LE), target)
222
223
class TestHeapPython(TestHeap):
    """Run the shared TestHeap suite against heapq with _heapq blocked."""
    module = py_heapq
226
227
@skipUnless(c_heapq, 'requires _heapq')
class TestHeapC(TestHeap):
    """Run the shared TestHeap suite against heapq backed by _heapq."""
    module = c_heapq
231
232
233#==============================================================================
234
class LenOnly:
    """Sequence stand-in that supports len() but not item access."""
    def __len__(self):
        # Any non-zero length will do.
        return 10
239
class GetOnly:
    """Sequence stand-in that supports item access but not len()."""
    def __getitem__(self, index):
        # Every index maps to the same constant value.
        return 10
244
class CmpErr:
    """Dummy element that always raises an error during comparison.

    Every ordering comparison raises ZeroDivisionError, so tests can check
    that heapq propagates errors from comparisons.  __cmp__ covers Python 2's
    fallback path.  The explicit rich comparison methods give the same
    guarantee when <, <=, > or >= are looked up directly, which is the only
    path on Python 3, where __cmp__ is ignored.
    """
    def __cmp__(self, other):
        raise ZeroDivisionError

    def __lt__(self, other):
        raise ZeroDivisionError

    # Reflected comparisons (e.g. ``10 < CmpErr()`` -> __gt__) must fail too.
    __le__ = __gt__ = __ge__ = __lt__
249
def R(seqn):
    """Plain generator function that re-yields every item of seqn."""
    for element in seqn:
        yield element
254
class G:
    """Iterable only through the legacy __getitem__ sequence protocol."""
    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, index):
        # Iteration calls this with 0, 1, 2, ... until the wrapped sequence
        # raises IndexError.
        return self.seqn[index]
261
class I:
    """Iterator implemented with the next()/__iter__ protocol methods."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            # Read first, then advance: a failing lookup leaves i unchanged.
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
274
class Ig:
    """Iterable whose __iter__ is itself a generator over the sequence."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # not read by any code in this file

    def __iter__(self):
        for item in self.seqn:
            yield item
283
class X:
    """Has next() but neither __iter__ nor __getitem__, so is not iterable."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
294
class N:
    """__iter__ returns self, but there is no next(): a broken iterator."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self
302
class E:
    """Iterator whose next() always fails, to test exception propagation."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        # Deliberate division by zero: the ZeroDivisionError must reach
        # whoever is consuming the iterator.
        3 // 0
312
class S:
    """Iterator that is already exhausted on its first next() call."""
    def __init__(self, seqn):
        pass  # the wrapped sequence is intentionally ignored

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
321
322from itertools import chain, imap
def L(seqn):
    """Wrap seqn in several stacked iterator layers.

    Layers, innermost first: G (__getitem__), Ig (generator __iter__),
    R (generator function), imap with an identity function, then chain.
    """
    layered = R(Ig(G(seqn)))
    return chain(imap(lambda item: item, layered))
326
class SideEffectLT:
    """Heap element whose < comparison empties the list it lives in."""

    def __init__(self, value, heap):
        self.value = value
        self.heap = heap

    def __lt__(self, other):
        # Clear the shared list in place before answering, so the heap
        # shrinks in the middle of a heapq operation (Issue #17278).
        del self.heap[:]
        return self.value < other.value
335
336
class TestErrorHandling(TestCase):
    """Error-path tests shared by every heapq implementation.

    ``module`` is None here; concrete subclasses bind it to the heapq module
    under test.
    """
    module = None

    def test_non_sequence(self):
        # A plain int is neither a list nor an iterable.
        mod = self.module
        for func in (mod.heapify, mod.heappop):
            self.assertRaises((TypeError, AttributeError), func, 10)
        for func in (mod.heappush, mod.heapreplace,
                     mod.nlargest, mod.nsmallest):
            self.assertRaises((TypeError, AttributeError), func, 10, 10)

    def test_len_only(self):
        # LenOnly has __len__ but no item access or iteration.
        mod = self.module
        expected = (TypeError, AttributeError)
        for func in (mod.heapify, mod.heappop):
            self.assertRaises(expected, func, LenOnly())
        for func in (mod.heappush, mod.heapreplace):
            self.assertRaises(expected, func, LenOnly(), 10)
        for func in (mod.nlargest, mod.nsmallest):
            self.assertRaises(TypeError, func, 2, LenOnly())

    def test_get_only(self):
        # Despite its name, this test uses CmpErr elements, whose comparisons
        # raise ZeroDivisionError.  The same list is reused across calls, in
        # this exact order.
        mod = self.module
        seq = [CmpErr(), CmpErr(), CmpErr()]
        for func in (mod.heapify, mod.heappop):
            self.assertRaises(ZeroDivisionError, func, seq)
        for func in (mod.heappush, mod.heapreplace):
            self.assertRaises(ZeroDivisionError, func, seq, 10)
        for func in (mod.nlargest, mod.nsmallest):
            self.assertRaises(ZeroDivisionError, func, 2, seq)

    def test_arg_parsing(self):
        # A single int argument: wrong type and/or too few arguments.
        mod = self.module
        every_function = (mod.heapify, mod.heappop,
                          mod.heappush, mod.heapreplace,
                          mod.nlargest, mod.nsmallest)
        for func in every_function:
            self.assertRaises((TypeError, AttributeError), func, 10)

    def test_iterable_args(self):
        # nlargest/nsmallest must accept any iterable and propagate failures.
        samples = ("123", "", range(1000), ('do', 1.2),
                   xrange(2000, 2200, 5))
        wrappers = (G, I, Ig, L, R)
        for func in (self.module.nlargest, self.module.nsmallest):
            for s in samples:
                for wrap in wrappers:
                    with test_support.check_py3k_warnings(
                            ("comparing unequal types not supported",
                             DeprecationWarning), quiet=True):
                        self.assertEqual(func(2, wrap(s)), func(2, s))
                self.assertEqual(func(2, S(s)), [])
                self.assertRaises(TypeError, func, 2, X(s))
                self.assertRaises(TypeError, func, 2, N(s))
                self.assertRaises(ZeroDivisionError, func, 2, E(s))

    # Issue #17278: the heap may change size while it's being walked.

    def test_heappush_mutating_heap(self):
        heap = []
        heap[:] = [SideEffectLT(i, heap) for i in range(200)]
        # Python version raises IndexError, C version RuntimeError
        with self.assertRaises((IndexError, RuntimeError)):
            self.module.heappush(heap, SideEffectLT(5, heap))

    def test_heappop_mutating_heap(self):
        heap = []
        heap[:] = [SideEffectLT(i, heap) for i in range(200)]
        # Python version raises IndexError, C version RuntimeError
        with self.assertRaises((IndexError, RuntimeError)):
            self.module.heappop(heap)
398
399
class TestErrorHandlingPython(TestErrorHandling):
    """Error-path suite against heapq with the _heapq accelerator blocked."""
    module = py_heapq
402
403
@skipUnless(c_heapq, 'requires _heapq')
class TestErrorHandlingC(TestErrorHandling):
    """Error-path suite against heapq backed by the _heapq module."""
    module = c_heapq
407
408
409#==============================================================================
410
411
def test_main(verbose=None):
    """Run every heapq test class, optionally reporting reference counts.

    If *verbose* is true and sys.gettotalrefcount exists (CPython debug
    builds), the suite is re-run five more times.  The total reference count
    after each run is printed; growth across runs hints at a leak.
    """
    test_classes = [TestModules, TestHeapPython, TestHeapC,
                    TestErrorHandlingPython, TestErrorHandlingC]
    test_support.run_unittest(*test_classes)

    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return

    import gc
    counts = []
    for _ in range(5):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
426
if __name__ == "__main__":
    # When run directly, verbose=True also enables test_main's reference-count
    # report, which only prints where sys.gettotalrefcount is available.
    test_main(verbose=True)
429