1import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Minimal plain class with no behaviour of its own.

    Instances accept arbitrary attributes and can be weakly referenced,
    which is all the tests need from it.
    """
18
class SomeClass(object):
    """Hashable wrapper around ``value`` used as a WeakSet element.

    Two instances compare equal only when they have exactly the same type
    and equal ``value`` attributes, so freshly built copies can stand in
    for an original element in membership tests.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of any other type are never equal to a SomeClass.
        return False if type(other) != type(self) else other.value == self.value

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out.
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        # Keyed on the class and the value, consistent with __eq__.
        key = (SomeClass, self.value)
        return hash(key)
32
class RefCycle(object):
    """Object that points at itself through its ``cycle`` attribute.

    Because of the self-reference, reference counting alone never frees an
    instance; only the cyclic garbage collector can.
    """

    def __init__(self):
        me = self
        me.cycle = me
36
class TestWeakSet(unittest.TestCase):
    """Behavioural tests for weakref.WeakSet.

    Most checks mirror the built-in set API, comparing a WeakSet against a
    dict or list holding the same SomeClass items.  Others drop the last
    strong reference to an element and verify that the WeakSet forgets it,
    including while an iterator over the set is still alive.
    """

    def setUp(self):
        # WeakSet only holds weak references, so every fixture element must
        # also be kept alive by one of these strong-reference lists.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        # self.obj is the only strong reference to the single element of
        # self.fs; several tests delete it to watch the element disappear.
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the built-in set must exist on WeakSet.
        # 'test_c_api' is excluded; presumably it is an internal helper that
        # only some builds of set expose (TODO confirm).
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                         "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference removes the element.
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Release the container's strong references to items2.
            del c
        # The result holds its elements weakly too: dropping the last
        # strong reference to one of them shrinks it.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        # i only contains items2 elements, which share nothing with items.
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # The named methods must also accept plain iterables of elements.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.abcde_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.abcde_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # There are no assertions: the test only has to complete cleanly.
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s=H()
        f=set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again replaces the contents.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # A temporary with no other strong reference vanishes immediately.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed item, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)
        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start the iterator.  Keep only the yielded item's value so
                # this frame holds no strong reference to the item itself.
                yielded_value = next(it).value
                # Schedule an item for removal and recreate an equal one.
                # SomeClass defines no __str__, so str() of an instance is
                # its default repr; copy .value to get an equal object.
                removed_value = items.pop().value
                u = SomeClass(removed_value)
                if yielded_value == removed_value:
                    # The suspended iterator still holds a strong reference
                    # to the removed item; advance it so the item can die
                    # (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
450
451
def test_main(verbose=None):
    """Run every test case defined in this module.

    *verbose* is accepted but ignored.
    """
    test_classes = (TestWeakSet,)
    test_support.run_unittest(*test_classes)


if __name__ == "__main__":
    # Allow running this file directly as a script.
    test_main(verbose=True)
457