1"""Unittests for heapq.""" 2 3import random 4import unittest 5import doctest 6 7from test import support 8from unittest import TestCase, skipUnless 9from operator import itemgetter 10 11py_heapq = support.import_fresh_module('heapq', blocked=['_heapq']) 12c_heapq = support.import_fresh_module('heapq', fresh=['_heapq']) 13 14# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when 15# _heapq is imported, so check them there 16func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace', 17 '_heappop_max', '_heapreplace_max', '_heapify_max'] 18 19class TestModules(TestCase): 20 def test_py_functions(self): 21 for fname in func_names: 22 self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq') 23 24 @skipUnless(c_heapq, 'requires _heapq') 25 def test_c_functions(self): 26 for fname in func_names: 27 self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq') 28 29 30def load_tests(loader, tests, ignore): 31 # The 'merge' function has examples in its docstring which we should test 32 # with 'doctest'. 33 # 34 # However, doctest can't easily find all docstrings in the module (loading 35 # it through import_fresh_module seems to confuse it), so we specifically 36 # create a finder which returns the doctests from the merge method. 37 38 class HeapqMergeDocTestFinder: 39 def find(self, *args, **kwargs): 40 dtf = doctest.DocTestFinder() 41 return dtf.find(py_heapq.merge) 42 43 tests.addTests(doctest.DocTestSuite(py_heapq, 44 test_finder=HeapqMergeDocTestFinder())) 45 return tests 46 47class TestHeap: 48 49 def test_push_pop(self): 50 # 1) Push 256 random numbers and pop them off, verifying all's OK. 51 heap = [] 52 data = [] 53 self.check_invariant(heap) 54 for i in range(256): 55 item = random.random() 56 data.append(item) 57 self.module.heappush(heap, item) 58 self.check_invariant(heap) 59 results = [] 60 while heap: 61 item = self.module.heappop(heap) 62 self.check_invariant(heap) 63 results.append(item) 64 data_sorted = data[:] 65 data_sorted.sort() 66 self.assertEqual(data_sorted, results) 67 # 2) Check that the invariant holds for a sorted array 68 self.check_invariant(results) 69 70 self.assertRaises(TypeError, self.module.heappush, []) 71 try: 72 self.assertRaises(TypeError, self.module.heappush, None, None) 73 self.assertRaises(TypeError, self.module.heappop, None) 74 except AttributeError: 75 pass 76 77 def check_invariant(self, heap): 78 # Check the heap invariant. 79 for pos, item in enumerate(heap): 80 if pos: # pos 0 has no parent 81 parentpos = (pos-1) >> 1 82 self.assertTrue(heap[parentpos] <= item) 83 84 def test_heapify(self): 85 for size in list(range(30)) + [20000]: 86 heap = [random.random() for dummy in range(size)] 87 self.module.heapify(heap) 88 self.check_invariant(heap) 89 90 self.assertRaises(TypeError, self.module.heapify, None) 91 92 def test_naive_nbest(self): 93 data = [random.randrange(2000) for i in range(1000)] 94 heap = [] 95 for item in data: 96 self.module.heappush(heap, item) 97 if len(heap) > 10: 98 self.module.heappop(heap) 99 heap.sort() 100 self.assertEqual(heap, sorted(data)[-10:]) 101 102 def heapiter(self, heap): 103 # An iterator returning a heap's elements, smallest-first. 104 try: 105 while 1: 106 yield self.module.heappop(heap) 107 except IndexError: 108 pass 109 110 def test_nbest(self): 111 # Less-naive "N-best" algorithm, much faster (if len(data) is big 112 # enough <wink>) than sorting all of data. However, if we had a max 113 # heap instead of a min heap, it could go faster still via 114 # heapify'ing all of data (linear time), then doing 10 heappops 115 # (10 log-time steps). 116 data = [random.randrange(2000) for i in range(1000)] 117 heap = data[:10] 118 self.module.heapify(heap) 119 for item in data[10:]: 120 if item > heap[0]: # this gets rarer the longer we run 121 self.module.heapreplace(heap, item) 122 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 123 124 self.assertRaises(TypeError, self.module.heapreplace, None) 125 self.assertRaises(TypeError, self.module.heapreplace, None, None) 126 self.assertRaises(IndexError, self.module.heapreplace, [], None) 127 128 def test_nbest_with_pushpop(self): 129 data = [random.randrange(2000) for i in range(1000)] 130 heap = data[:10] 131 self.module.heapify(heap) 132 for item in data[10:]: 133 self.module.heappushpop(heap, item) 134 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 135 self.assertEqual(self.module.heappushpop([], 'x'), 'x') 136 137 def test_heappushpop(self): 138 h = [] 139 x = self.module.heappushpop(h, 10) 140 self.assertEqual((h, x), ([], 10)) 141 142 h = [10] 143 x = self.module.heappushpop(h, 10.0) 144 self.assertEqual((h, x), ([10], 10.0)) 145 self.assertEqual(type(h[0]), int) 146 self.assertEqual(type(x), float) 147 148 h = [10]; 149 x = self.module.heappushpop(h, 9) 150 self.assertEqual((h, x), ([10], 9)) 151 152 h = [10]; 153 x = self.module.heappushpop(h, 11) 154 self.assertEqual((h, x), ([11], 10)) 155 156 def test_heappop_max(self): 157 # _heapop_max has an optimization for one-item lists which isn't 158 # covered in other tests, so test that case explicitly here 159 h = [3, 2] 160 self.assertEqual(self.module._heappop_max(h), 3) 161 self.assertEqual(self.module._heappop_max(h), 2) 162 163 def test_heapsort(self): 164 # Exercise everything with repeated heapsort checks 165 for trial in range(100): 166 size = random.randrange(50) 167 data = [random.randrange(25) for i in range(size)] 168 if trial & 1: # Half of the time, use heapify 169 heap = data[:] 170 self.module.heapify(heap) 171 else: # The rest of the time, use heappush 172 heap = [] 173 for item in data: 174 self.module.heappush(heap, item) 175 heap_sorted = [self.module.heappop(heap) for i in range(size)] 176 self.assertEqual(heap_sorted, sorted(data)) 177 178 def test_merge(self): 179 inputs = [] 180 for i in range(random.randrange(25)): 181 row = [] 182 for j in range(random.randrange(100)): 183 tup = random.choice('ABC'), random.randrange(-500, 500) 184 row.append(tup) 185 inputs.append(row) 186 187 for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: 188 for reverse in [False, True]: 189 seqs = [] 190 for seq in inputs: 191 seqs.append(sorted(seq, key=key, reverse=reverse)) 192 self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), 193 list(self.module.merge(*seqs, key=key, reverse=reverse))) 194 self.assertEqual(list(self.module.merge()), []) 195 196 def test_empty_merges(self): 197 # Merging two empty lists (with or without a key) should produce 198 # another empty list. 199 self.assertEqual(list(self.module.merge([], [])), []) 200 self.assertEqual(list(self.module.merge([], [], key=lambda: 6)), []) 201 202 def test_merge_does_not_suppress_index_error(self): 203 # Issue 19018: Heapq.merge suppresses IndexError from user generator 204 def iterable(): 205 s = list(range(10)) 206 for i in range(20): 207 yield s[i] # IndexError when i > 10 208 with self.assertRaises(IndexError): 209 list(self.module.merge(iterable(), iterable())) 210 211 def test_merge_stability(self): 212 class Int(int): 213 pass 214 inputs = [[], [], [], []] 215 for i in range(20000): 216 stream = random.randrange(4) 217 x = random.randrange(500) 218 obj = Int(x) 219 obj.pair = (x, stream) 220 inputs[stream].append(obj) 221 for stream in inputs: 222 stream.sort() 223 result = [i.pair for i in self.module.merge(*inputs)] 224 self.assertEqual(result, sorted(result)) 225 226 def test_nsmallest(self): 227 data = [(random.randrange(2000), i) for i in range(1000)] 228 for f in (None, lambda x: x[0] * 547 % 2000): 229 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 230 self.assertEqual(list(self.module.nsmallest(n, data)), 231 sorted(data)[:n]) 232 self.assertEqual(list(self.module.nsmallest(n, data, key=f)), 233 sorted(data, key=f)[:n]) 234 235 def test_nlargest(self): 236 data = [(random.randrange(2000), i) for i in range(1000)] 237 for f in (None, lambda x: x[0] * 547 % 2000): 238 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 239 self.assertEqual(list(self.module.nlargest(n, data)), 240 sorted(data, reverse=True)[:n]) 241 self.assertEqual(list(self.module.nlargest(n, data, key=f)), 242 sorted(data, key=f, reverse=True)[:n]) 243 244 def test_comparison_operator(self): 245 # Issue 3051: Make sure heapq works with both __lt__ 246 # For python 3.0, __le__ alone is not enough 247 def hsort(data, comp): 248 data = [comp(x) for x in data] 249 self.module.heapify(data) 250 return [self.module.heappop(data).x for i in range(len(data))] 251 class LT: 252 def __init__(self, x): 253 self.x = x 254 def __lt__(self, other): 255 return self.x > other.x 256 class LE: 257 def __init__(self, x): 258 self.x = x 259 def __le__(self, other): 260 return self.x >= other.x 261 data = [random.random() for i in range(100)] 262 target = sorted(data, reverse=True) 263 self.assertEqual(hsort(data, LT), target) 264 self.assertRaises(TypeError, data, LE) 265 266 267class TestHeapPython(TestHeap, TestCase): 268 module = py_heapq 269 270 271@skipUnless(c_heapq, 'requires _heapq') 272class TestHeapC(TestHeap, TestCase): 273 module = c_heapq 274 275 276#============================================================================== 277 278class LenOnly: 279 "Dummy sequence class defining __len__ but not __getitem__." 280 def __len__(self): 281 return 10 282 283class CmpErr: 284 "Dummy element that always raises an error during comparison" 285 def __eq__(self, other): 286 raise ZeroDivisionError 287 __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__ 288 289def R(seqn): 290 'Regular generator' 291 for i in seqn: 292 yield i 293 294class G: 295 'Sequence using __getitem__' 296 def __init__(self, seqn): 297 self.seqn = seqn 298 def __getitem__(self, i): 299 return self.seqn[i] 300 301class I: 302 'Sequence using iterator protocol' 303 def __init__(self, seqn): 304 self.seqn = seqn 305 self.i = 0 306 def __iter__(self): 307 return self 308 def __next__(self): 309 if self.i >= len(self.seqn): raise StopIteration 310 v = self.seqn[self.i] 311 self.i += 1 312 return v 313 314class Ig: 315 'Sequence using iterator protocol defined with a generator' 316 def __init__(self, seqn): 317 self.seqn = seqn 318 self.i = 0 319 def __iter__(self): 320 for val in self.seqn: 321 yield val 322 323class X: 324 'Missing __getitem__ and __iter__' 325 def __init__(self, seqn): 326 self.seqn = seqn 327 self.i = 0 328 def __next__(self): 329 if self.i >= len(self.seqn): raise StopIteration 330 v = self.seqn[self.i] 331 self.i += 1 332 return v 333 334class N: 335 'Iterator missing __next__()' 336 def __init__(self, seqn): 337 self.seqn = seqn 338 self.i = 0 339 def __iter__(self): 340 return self 341 342class E: 343 'Test propagation of exceptions' 344 def __init__(self, seqn): 345 self.seqn = seqn 346 self.i = 0 347 def __iter__(self): 348 return self 349 def __next__(self): 350 3 // 0 351 352class S: 353 'Test immediate stop' 354 def __init__(self, seqn): 355 pass 356 def __iter__(self): 357 return self 358 def __next__(self): 359 raise StopIteration 360 361from itertools import chain 362def L(seqn): 363 'Test multiple tiers of iterators' 364 return chain(map(lambda x:x, R(Ig(G(seqn))))) 365 366 367class SideEffectLT: 368 def __init__(self, value, heap): 369 self.value = value 370 self.heap = heap 371 372 def __lt__(self, other): 373 self.heap[:] = [] 374 return self.value < other.value 375 376 377class TestErrorHandling: 378 379 def test_non_sequence(self): 380 for f in (self.module.heapify, self.module.heappop): 381 self.assertRaises((TypeError, AttributeError), f, 10) 382 for f in (self.module.heappush, self.module.heapreplace, 383 self.module.nlargest, self.module.nsmallest): 384 self.assertRaises((TypeError, AttributeError), f, 10, 10) 385 386 def test_len_only(self): 387 for f in (self.module.heapify, self.module.heappop): 388 self.assertRaises((TypeError, AttributeError), f, LenOnly()) 389 for f in (self.module.heappush, self.module.heapreplace): 390 self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) 391 for f in (self.module.nlargest, self.module.nsmallest): 392 self.assertRaises(TypeError, f, 2, LenOnly()) 393 394 def test_cmp_err(self): 395 seq = [CmpErr(), CmpErr(), CmpErr()] 396 for f in (self.module.heapify, self.module.heappop): 397 self.assertRaises(ZeroDivisionError, f, seq) 398 for f in (self.module.heappush, self.module.heapreplace): 399 self.assertRaises(ZeroDivisionError, f, seq, 10) 400 for f in (self.module.nlargest, self.module.nsmallest): 401 self.assertRaises(ZeroDivisionError, f, 2, seq) 402 403 def test_arg_parsing(self): 404 for f in (self.module.heapify, self.module.heappop, 405 self.module.heappush, self.module.heapreplace, 406 self.module.nlargest, self.module.nsmallest): 407 self.assertRaises((TypeError, AttributeError), f, 10) 408 409 def test_iterable_args(self): 410 for f in (self.module.nlargest, self.module.nsmallest): 411 for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)): 412 for g in (G, I, Ig, L, R): 413 self.assertEqual(list(f(2, g(s))), list(f(2,s))) 414 self.assertEqual(list(f(2, S(s))), []) 415 self.assertRaises(TypeError, f, 2, X(s)) 416 self.assertRaises(TypeError, f, 2, N(s)) 417 self.assertRaises(ZeroDivisionError, f, 2, E(s)) 418 419 # Issue #17278: the heap may change size while it's being walked. 420 421 def test_heappush_mutating_heap(self): 422 heap = [] 423 heap.extend(SideEffectLT(i, heap) for i in range(200)) 424 # Python version raises IndexError, C version RuntimeError 425 with self.assertRaises((IndexError, RuntimeError)): 426 self.module.heappush(heap, SideEffectLT(5, heap)) 427 428 def test_heappop_mutating_heap(self): 429 heap = [] 430 heap.extend(SideEffectLT(i, heap) for i in range(200)) 431 # Python version raises IndexError, C version RuntimeError 432 with self.assertRaises((IndexError, RuntimeError)): 433 self.module.heappop(heap) 434 435 def test_comparison_operator_modifiying_heap(self): 436 # See bpo-39421: Strong references need to be taken 437 # when comparing objects as they can alter the heap 438 class EvilClass(int): 439 def __lt__(self, o): 440 heap.clear() 441 return NotImplemented 442 443 heap = [] 444 self.module.heappush(heap, EvilClass(0)) 445 self.assertRaises(IndexError, self.module.heappushpop, heap, 1) 446 447 def test_comparison_operator_modifiying_heap_two_heaps(self): 448 449 class h(int): 450 def __lt__(self, o): 451 list2.clear() 452 return NotImplemented 453 454 class g(int): 455 def __lt__(self, o): 456 list1.clear() 457 return NotImplemented 458 459 list1, list2 = [], [] 460 461 self.module.heappush(list1, h(0)) 462 self.module.heappush(list2, g(0)) 463 464 self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1)) 465 self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1)) 466 467class TestErrorHandlingPython(TestErrorHandling, TestCase): 468 module = py_heapq 469 470@skipUnless(c_heapq, 'requires _heapq') 471class TestErrorHandlingC(TestErrorHandling, TestCase): 472 module = c_heapq 473 474 475if __name__ == "__main__": 476 unittest.main() 477