1"""Unittests for heapq.""" 2 3import random 4import unittest 5 6from test import support 7from unittest import TestCase, skipUnless 8from operator import itemgetter 9 10py_heapq = support.import_fresh_module('heapq', blocked=['_heapq']) 11c_heapq = support.import_fresh_module('heapq', fresh=['_heapq']) 12 13# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when 14# _heapq is imported, so check them there 15func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace', 16 '_heappop_max', '_heapreplace_max', '_heapify_max'] 17 18class TestModules(TestCase): 19 def test_py_functions(self): 20 for fname in func_names: 21 self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq') 22 23 @skipUnless(c_heapq, 'requires _heapq') 24 def test_c_functions(self): 25 for fname in func_names: 26 self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq') 27 28 29class TestHeap: 30 31 def test_push_pop(self): 32 # 1) Push 256 random numbers and pop them off, verifying all's OK. 33 heap = [] 34 data = [] 35 self.check_invariant(heap) 36 for i in range(256): 37 item = random.random() 38 data.append(item) 39 self.module.heappush(heap, item) 40 self.check_invariant(heap) 41 results = [] 42 while heap: 43 item = self.module.heappop(heap) 44 self.check_invariant(heap) 45 results.append(item) 46 data_sorted = data[:] 47 data_sorted.sort() 48 self.assertEqual(data_sorted, results) 49 # 2) Check that the invariant holds for a sorted array 50 self.check_invariant(results) 51 52 self.assertRaises(TypeError, self.module.heappush, []) 53 try: 54 self.assertRaises(TypeError, self.module.heappush, None, None) 55 self.assertRaises(TypeError, self.module.heappop, None) 56 except AttributeError: 57 pass 58 59 def check_invariant(self, heap): 60 # Check the heap invariant. 61 for pos, item in enumerate(heap): 62 if pos: # pos 0 has no parent 63 parentpos = (pos-1) >> 1 64 self.assertTrue(heap[parentpos] <= item) 65 66 def test_heapify(self): 67 for size in list(range(30)) + [20000]: 68 heap = [random.random() for dummy in range(size)] 69 self.module.heapify(heap) 70 self.check_invariant(heap) 71 72 self.assertRaises(TypeError, self.module.heapify, None) 73 74 def test_naive_nbest(self): 75 data = [random.randrange(2000) for i in range(1000)] 76 heap = [] 77 for item in data: 78 self.module.heappush(heap, item) 79 if len(heap) > 10: 80 self.module.heappop(heap) 81 heap.sort() 82 self.assertEqual(heap, sorted(data)[-10:]) 83 84 def heapiter(self, heap): 85 # An iterator returning a heap's elements, smallest-first. 86 try: 87 while 1: 88 yield self.module.heappop(heap) 89 except IndexError: 90 pass 91 92 def test_nbest(self): 93 # Less-naive "N-best" algorithm, much faster (if len(data) is big 94 # enough <wink>) than sorting all of data. However, if we had a max 95 # heap instead of a min heap, it could go faster still via 96 # heapify'ing all of data (linear time), then doing 10 heappops 97 # (10 log-time steps). 98 data = [random.randrange(2000) for i in range(1000)] 99 heap = data[:10] 100 self.module.heapify(heap) 101 for item in data[10:]: 102 if item > heap[0]: # this gets rarer the longer we run 103 self.module.heapreplace(heap, item) 104 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 105 106 self.assertRaises(TypeError, self.module.heapreplace, None) 107 self.assertRaises(TypeError, self.module.heapreplace, None, None) 108 self.assertRaises(IndexError, self.module.heapreplace, [], None) 109 110 def test_nbest_with_pushpop(self): 111 data = [random.randrange(2000) for i in range(1000)] 112 heap = data[:10] 113 self.module.heapify(heap) 114 for item in data[10:]: 115 self.module.heappushpop(heap, item) 116 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 117 self.assertEqual(self.module.heappushpop([], 'x'), 'x') 118 119 def test_heappushpop(self): 120 h = [] 121 x = self.module.heappushpop(h, 10) 122 self.assertEqual((h, x), ([], 10)) 123 124 h = [10] 125 x = self.module.heappushpop(h, 10.0) 126 self.assertEqual((h, x), ([10], 10.0)) 127 self.assertEqual(type(h[0]), int) 128 self.assertEqual(type(x), float) 129 130 h = [10]; 131 x = self.module.heappushpop(h, 9) 132 self.assertEqual((h, x), ([10], 9)) 133 134 h = [10]; 135 x = self.module.heappushpop(h, 11) 136 self.assertEqual((h, x), ([11], 10)) 137 138 def test_heapsort(self): 139 # Exercise everything with repeated heapsort checks 140 for trial in range(100): 141 size = random.randrange(50) 142 data = [random.randrange(25) for i in range(size)] 143 if trial & 1: # Half of the time, use heapify 144 heap = data[:] 145 self.module.heapify(heap) 146 else: # The rest of the time, use heappush 147 heap = [] 148 for item in data: 149 self.module.heappush(heap, item) 150 heap_sorted = [self.module.heappop(heap) for i in range(size)] 151 self.assertEqual(heap_sorted, sorted(data)) 152 153 def test_merge(self): 154 inputs = [] 155 for i in range(random.randrange(25)): 156 row = [] 157 for j in range(random.randrange(100)): 158 tup = random.choice('ABC'), random.randrange(-500, 500) 159 row.append(tup) 160 inputs.append(row) 161 162 for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: 163 for reverse in [False, True]: 164 seqs = [] 165 for seq in inputs: 166 seqs.append(sorted(seq, key=key, reverse=reverse)) 167 self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), 168 list(self.module.merge(*seqs, key=key, reverse=reverse))) 169 self.assertEqual(list(self.module.merge()), []) 170 171 def test_merge_does_not_suppress_index_error(self): 172 # Issue 19018: Heapq.merge suppresses IndexError from user generator 173 def iterable(): 174 s = list(range(10)) 175 for i in range(20): 176 yield s[i] # IndexError when i > 10 177 with self.assertRaises(IndexError): 178 list(self.module.merge(iterable(), iterable())) 179 180 def test_merge_stability(self): 181 class Int(int): 182 pass 183 inputs = [[], [], [], []] 184 for i in range(20000): 185 stream = random.randrange(4) 186 x = random.randrange(500) 187 obj = Int(x) 188 obj.pair = (x, stream) 189 inputs[stream].append(obj) 190 for stream in inputs: 191 stream.sort() 192 result = [i.pair for i in self.module.merge(*inputs)] 193 self.assertEqual(result, sorted(result)) 194 195 def test_nsmallest(self): 196 data = [(random.randrange(2000), i) for i in range(1000)] 197 for f in (None, lambda x: x[0] * 547 % 2000): 198 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 199 self.assertEqual(list(self.module.nsmallest(n, data)), 200 sorted(data)[:n]) 201 self.assertEqual(list(self.module.nsmallest(n, data, key=f)), 202 sorted(data, key=f)[:n]) 203 204 def test_nlargest(self): 205 data = [(random.randrange(2000), i) for i in range(1000)] 206 for f in (None, lambda x: x[0] * 547 % 2000): 207 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 208 self.assertEqual(list(self.module.nlargest(n, data)), 209 sorted(data, reverse=True)[:n]) 210 self.assertEqual(list(self.module.nlargest(n, data, key=f)), 211 sorted(data, key=f, reverse=True)[:n]) 212 213 def test_comparison_operator(self): 214 # Issue 3051: Make sure heapq works with both __lt__ 215 # For python 3.0, __le__ alone is not enough 216 def hsort(data, comp): 217 data = [comp(x) for x in data] 218 self.module.heapify(data) 219 return [self.module.heappop(data).x for i in range(len(data))] 220 class LT: 221 def __init__(self, x): 222 self.x = x 223 def __lt__(self, other): 224 return self.x > other.x 225 class LE: 226 def __init__(self, x): 227 self.x = x 228 def __le__(self, other): 229 return self.x >= other.x 230 data = [random.random() for i in range(100)] 231 target = sorted(data, reverse=True) 232 self.assertEqual(hsort(data, LT), target) 233 self.assertRaises(TypeError, data, LE) 234 235 236class TestHeapPython(TestHeap, TestCase): 237 module = py_heapq 238 239 240@skipUnless(c_heapq, 'requires _heapq') 241class TestHeapC(TestHeap, TestCase): 242 module = c_heapq 243 244 245#============================================================================== 246 247class LenOnly: 248 "Dummy sequence class defining __len__ but not __getitem__." 249 def __len__(self): 250 return 10 251 252class GetOnly: 253 "Dummy sequence class defining __getitem__ but not __len__." 254 def __getitem__(self, ndx): 255 return 10 256 257class CmpErr: 258 "Dummy element that always raises an error during comparison" 259 def __eq__(self, other): 260 raise ZeroDivisionError 261 __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__ 262 263def R(seqn): 264 'Regular generator' 265 for i in seqn: 266 yield i 267 268class G: 269 'Sequence using __getitem__' 270 def __init__(self, seqn): 271 self.seqn = seqn 272 def __getitem__(self, i): 273 return self.seqn[i] 274 275class I: 276 'Sequence using iterator protocol' 277 def __init__(self, seqn): 278 self.seqn = seqn 279 self.i = 0 280 def __iter__(self): 281 return self 282 def __next__(self): 283 if self.i >= len(self.seqn): raise StopIteration 284 v = self.seqn[self.i] 285 self.i += 1 286 return v 287 288class Ig: 289 'Sequence using iterator protocol defined with a generator' 290 def __init__(self, seqn): 291 self.seqn = seqn 292 self.i = 0 293 def __iter__(self): 294 for val in self.seqn: 295 yield val 296 297class X: 298 'Missing __getitem__ and __iter__' 299 def __init__(self, seqn): 300 self.seqn = seqn 301 self.i = 0 302 def __next__(self): 303 if self.i >= len(self.seqn): raise StopIteration 304 v = self.seqn[self.i] 305 self.i += 1 306 return v 307 308class N: 309 'Iterator missing __next__()' 310 def __init__(self, seqn): 311 self.seqn = seqn 312 self.i = 0 313 def __iter__(self): 314 return self 315 316class E: 317 'Test propagation of exceptions' 318 def __init__(self, seqn): 319 self.seqn = seqn 320 self.i = 0 321 def __iter__(self): 322 return self 323 def __next__(self): 324 3 // 0 325 326class S: 327 'Test immediate stop' 328 def __init__(self, seqn): 329 pass 330 def __iter__(self): 331 return self 332 def __next__(self): 333 raise StopIteration 334 335from itertools import chain 336def L(seqn): 337 'Test multiple tiers of iterators' 338 return chain(map(lambda x:x, R(Ig(G(seqn))))) 339 340 341class SideEffectLT: 342 def __init__(self, value, heap): 343 self.value = value 344 self.heap = heap 345 346 def __lt__(self, other): 347 self.heap[:] = [] 348 return self.value < other.value 349 350 351class TestErrorHandling: 352 353 def test_non_sequence(self): 354 for f in (self.module.heapify, self.module.heappop): 355 self.assertRaises((TypeError, AttributeError), f, 10) 356 for f in (self.module.heappush, self.module.heapreplace, 357 self.module.nlargest, self.module.nsmallest): 358 self.assertRaises((TypeError, AttributeError), f, 10, 10) 359 360 def test_len_only(self): 361 for f in (self.module.heapify, self.module.heappop): 362 self.assertRaises((TypeError, AttributeError), f, LenOnly()) 363 for f in (self.module.heappush, self.module.heapreplace): 364 self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) 365 for f in (self.module.nlargest, self.module.nsmallest): 366 self.assertRaises(TypeError, f, 2, LenOnly()) 367 368 def test_get_only(self): 369 for f in (self.module.heapify, self.module.heappop): 370 self.assertRaises(TypeError, f, GetOnly()) 371 for f in (self.module.heappush, self.module.heapreplace): 372 self.assertRaises(TypeError, f, GetOnly(), 10) 373 for f in (self.module.nlargest, self.module.nsmallest): 374 self.assertRaises(TypeError, f, 2, GetOnly()) 375 376 def test_get_only(self): 377 seq = [CmpErr(), CmpErr(), CmpErr()] 378 for f in (self.module.heapify, self.module.heappop): 379 self.assertRaises(ZeroDivisionError, f, seq) 380 for f in (self.module.heappush, self.module.heapreplace): 381 self.assertRaises(ZeroDivisionError, f, seq, 10) 382 for f in (self.module.nlargest, self.module.nsmallest): 383 self.assertRaises(ZeroDivisionError, f, 2, seq) 384 385 def test_arg_parsing(self): 386 for f in (self.module.heapify, self.module.heappop, 387 self.module.heappush, self.module.heapreplace, 388 self.module.nlargest, self.module.nsmallest): 389 self.assertRaises((TypeError, AttributeError), f, 10) 390 391 def test_iterable_args(self): 392 for f in (self.module.nlargest, self.module.nsmallest): 393 for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)): 394 for g in (G, I, Ig, L, R): 395 self.assertEqual(list(f(2, g(s))), list(f(2,s))) 396 self.assertEqual(list(f(2, S(s))), []) 397 self.assertRaises(TypeError, f, 2, X(s)) 398 self.assertRaises(TypeError, f, 2, N(s)) 399 self.assertRaises(ZeroDivisionError, f, 2, E(s)) 400 401 # Issue #17278: the heap may change size while it's being walked. 402 403 def test_heappush_mutating_heap(self): 404 heap = [] 405 heap.extend(SideEffectLT(i, heap) for i in range(200)) 406 # Python version raises IndexError, C version RuntimeError 407 with self.assertRaises((IndexError, RuntimeError)): 408 self.module.heappush(heap, SideEffectLT(5, heap)) 409 410 def test_heappop_mutating_heap(self): 411 heap = [] 412 heap.extend(SideEffectLT(i, heap) for i in range(200)) 413 # Python version raises IndexError, C version RuntimeError 414 with self.assertRaises((IndexError, RuntimeError)): 415 self.module.heappop(heap) 416 417 418class TestErrorHandlingPython(TestErrorHandling, TestCase): 419 module = py_heapq 420 421@skipUnless(c_heapq, 'requires _heapq') 422class TestErrorHandlingC(TestErrorHandling, TestCase): 423 module = c_heapq 424 425 426if __name__ == "__main__": 427 unittest.main() 428