1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16import os
17import weakref
18import gc
19from weakref import proxy
20import contextlib
21
22from test.support.script_helper import assert_python_ok
23
24import functools
25
# Two separately imported copies of functools so that both implementations
# of partial can be tested side by side:
#   * py_functools -- imported with the _functools accelerator blocked, so it
#     uses the pure-Python definitions;
#   * c_functools  -- imported with a fresh _functools.  It may be falsy when
#     the accelerator is unavailable; the C-specific test classes guard on it.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh decimal module backed by the _decimal accelerator; presumably used by
# tests further down this file (not visible here) -- TODO confirm.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
30
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module that was registered under *name* is put back when the
    ``with`` block exits, whether it exits normally or by an exception.
    *name* must already be present in ``sys.modules`` (KeyError otherwise).
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, so a failing test cannot leak the substitute.
        sys.modules[name] = previous
39
def capture(*args, **kw):
    """Echo the call back as a ``(positional_tuple, keyword_dict)`` pair."""
    received = (args, kw)
    return received
43
44
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
48
class MyTuple(tuple):
    """Behaviour-free tuple subclass used as input to partial.__setstate__."""
51
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
55
class MyDict(dict):
    """Behaviour-free dict subclass used as input to partial.__setstate__."""
58
59
class TestPartial:
    """Implementation-independent tests for functools.partial.

    This is a mixin: concrete subclasses also derive from unittest.TestCase
    and provide ``partial`` (the partial type under test) and
    ``AllowPickle`` (a context manager entered around the pickling tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        # Call-time positionals are appended; call-time keywords override.
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable func must raise TypeError, either when the partial
        # is created or, at the latest, when it is called.
        try:
            self.partial(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertEqual calls report the offending values,
                # unlike assertTrue on a combined condition.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last strong reference frees the partial immediately
        # only under reference counting; force a collection for other GCs.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened so that it is
        # indistinguishable from a single partial with the combined arguments.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not asserted; accept either ordering.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself (as func, as a
        # positional arg, as a keyword value), checks the "..." placeholder,
        # and then breaks the cycle in ``finally``.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and instance attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates args, keywords, their contents and attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # State is a 4-tuple: (func, args, keywords, instance __dict__).
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None __dict__ is accepted and reads back as an empty dict.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None keywords mapping is accepted too.
        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # Wrong length, wrong container type, non-callable func, and
        # wrongly-typed args/keywords must all be rejected.
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to exact types,
        # and an overridden __add__ on the args tuple must not be used.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial that is its own func cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence that merely looks like a 4-tuple must be rejected.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
383
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C accelerator, plus C-only checks."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing needs patching for the C type."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # func/args/keywords are read-only on the C type.
        p = self.partial(capture, 1, 2, a=10, b=20)
        for attr, value in (('func', map),
                            ('args', (1, 2)),
                            ('keywords', dict(a=1, b=2))):
            self.assertRaises(AttributeError, setattr, p, attr, value)

        # __dict__ itself may not be removed.
        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Smuggle a non-string key into keywords: repr still shows it,
        # but calling the partial must fail.
        p.keywords[1234] = 'value'
        text = repr(p)
        for fragment in ('1234', "'value'"):
            self.assertIn(fragment, text)
        self.assertRaises(TypeError, p)

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                # Rebind this very key while repr() is formatting it.
                p.keywords[self] = ['sth2']
                return 'astr'

        # The value being formatted must stay alive even though the key's
        # __str__ replaces it in the dict mid-repr.
        p.keywords[MutatesYourDict()] = ['sth']
        text = repr(p)
        self.assertIn('astr', text)
        self.assertIn("['sth']", text)
434
435
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Install py_functools as ``functools`` for the duration of a block.

        Presumably needed because pickle locates the partial class through
        ``sys.modules['functools']``, which normally holds the C-backed
        module rather than py_functools.
        """

        def __init__(self):
            self._patch = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._patch.__enter__()

        def __exit__(self, type, value, tb):
            return self._patch.__exit__(type, value, tb)
446
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Empty subclass of the C partial type."""

class PyPartialSubclass(py_functools.partial):
    """Empty subclass of the pure-Python partial type."""
453
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests with a user-defined subclass."""

    if c_functools:
        partial = CPartialSubclass

    # Subclass instances are not flattened when nested, so the inherited
    # flattening test does not apply; setting it to None disables it.
    test_nested_optimization = None
461
class TestPartialPySubclass(TestPartialPy):
    """Repeat the pure-Python partial tests with a user-defined subclass."""

    partial = PyPartialSubclass
464
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class-body descriptor."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' are forwarded as ordinary keywords.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain partial object.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping static/class method descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static/cls behave identically whether looked up on class or instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Non-callable func, missing func, and func passed by keyword.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # The class must be *created by* ABCMeta (metaclass=), not derive
        # from it; deriving would make Abstract a metaclass and ABCMeta's
        # abstract-method bookkeeping would never apply to add/add5.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # Because add5 reports itself abstract, ABCMeta records both names
        # and refuses to instantiate the class.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
593
594
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata taken from *wrapped*.

        Each name in *assigned* must be the very same object on both
        functions.  Each mapping named in *updated* on *wrapper* must hold
        every entry of the corresponding mapping on *wrapped*, except
        ``__dict__['__wrapped__']``, which update_wrapper overwrites.
        Finally ``wrapper.__wrapped__`` must be *wrapped* itself.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                skip = attr == "__dict__" and key == "__wrapped__"
                if not skip:
                    self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Wrap an annotated, documented function using default arguments.

        Returns ``(wrapper, f)``.  ``f`` carries a bogus ``__wrapped__`` so
        the tests can confirm update_wrapper replaces it.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # Annotations come from the wrapped function, replacing the wrapper's.
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        # With nothing to assign or update, only __wrapped__ is set.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        # Names outside the custom tuples are left untouched.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Names missing from the wrapped function are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper itself must provide each "updated" attribute, and an
        # int in its place is rejected as well.
        del wrapper.dict_attr
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function.
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
707
708
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper tests through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a wrapper of a documented function with default wraps().

        Returns ``(wrapper, f)``.  ``f`` carries a bogus ``__wrapped__`` so
        the tests can confirm wraps() replaces it.
        """
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expectations = (
            ('__name__', 'f'),
            ('__qualname__', f.__qualname__),
            ('attr', 'This is also a test'),
        )
        for attr, expected in expectations:
            self.assertEqual(getattr(wrapper, attr), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        # Empty assigned/updated tuples: only __wrapped__ is set.
        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # Give the wrapper the mapping that wraps() will update.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
769
770
771class TestReduce:
772    def test_reduce(self):
773        class Squares:
774            def __init__(self, max):
775                self.max = max
776                self.sofar = []
777
778            def __len__(self):
779                return len(self.sofar)
780
781            def __getitem__(self, i):
782                if not 0 <= i < self.max: raise IndexError
783                n = len(self.sofar)
784                while n <= i:
785                    self.sofar.append(n*n)
786                    n += 1
787                return self.sofar[i]
788        def add(x, y):
789            return x + y
790        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
791        self.assertEqual(
792            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
793            ['a','c','d','w']
794        )
795        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
796        self.assertEqual(
797            self.reduce(lambda x, y: x*y, range(2,21), 1),
798            2432902008176640000
799        )
800        self.assertEqual(self.reduce(add, Squares(10)), 285)
801        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
802        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
803        self.assertRaises(TypeError, self.reduce)
804        self.assertRaises(TypeError, self.reduce, 42, 42)
805        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
806        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
807        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
808        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
809        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
810        self.assertRaises(TypeError, self.reduce, add, "")
811        self.assertRaises(TypeError, self.reduce, add, ())
812        self.assertRaises(TypeError, self.reduce, add, object())
813
814        class TestFailingIter:
815            def __iter__(self):
816                raise RuntimeError
817        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
818
819        self.assertEqual(self.reduce(add, [], None), None)
820        self.assertEqual(self.reduce(add, [], 42), 42)
821
822        class BadSeq:
823            def __getitem__(self, index):
824                raise ValueError
825        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
826
827    # Test reduce()'s use of iterators.
828    def test_iterator_usage(self):
829        class SequenceClass:
830            def __init__(self, n):
831                self.n = n
832            def __getitem__(self, i):
833                if 0 <= i < self.n:
834                    return i
835                else:
836                    raise IndexError
837
838        from operator import add
839        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
840        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
841        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
842        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
843        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
844        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
845
846        d = {"one": 1, "two": 2, "three": 3}
847        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
848
849
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Shared reduce() tests run against the C-accelerated functools.  The
    # attribute is bound only when that module was importable; otherwise the
    # skip decorator above keeps these tests from running.
    if c_functools is not None:
        reduce = staticmethod(c_functools.reduce)
854
855
class TestReducePy(TestReduce, unittest.TestCase):
    # Shared reduce() tests run against the pure Python functools
    # (py_functools is imported with the _functools accelerator blocked).
    # staticmethod keeps the plain function from being bound to the test
    # instance when it is looked up as self.reduce.
    reduce = staticmethod(py_functools.reduce)
858
859
860class TestCmpToKey:
861
862    def test_cmp_to_key(self):
863        def cmp1(x, y):
864            return (x > y) - (x < y)
865        key = self.cmp_to_key(cmp1)
866        self.assertEqual(key(3), key(3))
867        self.assertGreater(key(3), key(1))
868        self.assertGreaterEqual(key(3), key(3))
869
870        def cmp2(x, y):
871            return int(x) - int(y)
872        key = self.cmp_to_key(cmp2)
873        self.assertEqual(key(4.0), key('4'))
874        self.assertLess(key(2), key('35'))
875        self.assertLessEqual(key(2), key('35'))
876        self.assertNotEqual(key(2), key('35'))
877
878    def test_cmp_to_key_arguments(self):
879        def cmp1(x, y):
880            return (x > y) - (x < y)
881        key = self.cmp_to_key(mycmp=cmp1)
882        self.assertEqual(key(obj=3), key(obj=3))
883        self.assertGreater(key(obj=3), key(obj=1))
884        with self.assertRaises((TypeError, AttributeError)):
885            key(3) > 1    # rhs is not a K object
886        with self.assertRaises((TypeError, AttributeError)):
887            1 < key(3)    # lhs is not a K object
888        with self.assertRaises(TypeError):
889            key = self.cmp_to_key()             # too few args
890        with self.assertRaises(TypeError):
891            key = self.cmp_to_key(cmp1, None)   # too many args
892        key = self.cmp_to_key(cmp1)
893        with self.assertRaises(TypeError):
894            key()                                    # too few args
895        with self.assertRaises(TypeError):
896            key(None, None)                          # too many args
897
898    def test_bad_cmp(self):
899        def cmp1(x, y):
900            raise ZeroDivisionError
901        key = self.cmp_to_key(cmp1)
902        with self.assertRaises(ZeroDivisionError):
903            key(3) > key(1)
904
905        class BadCmp:
906            def __lt__(self, other):
907                raise ZeroDivisionError
908        def cmp1(x, y):
909            return BadCmp()
910        with self.assertRaises(ZeroDivisionError):
911            key(3) > key(1)
912
913    def test_obj_field(self):
914        def cmp1(x, y):
915            return (x > y) - (x < y)
916        key = self.cmp_to_key(mycmp=cmp1)
917        self.assertEqual(key(50).obj, 50)
918
919    def test_sort_int(self):
920        def mycmp(x, y):
921            return y - x
922        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
923                         [4, 3, 2, 1, 0])
924
925    def test_sort_int_str(self):
926        def mycmp(x, y):
927            x, y = int(x), int(y)
928            return (x > y) - (x < y)
929        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
930        values = sorted(values, key=self.cmp_to_key(mycmp))
931        self.assertEqual([int(value) for value in values],
932                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
933
934    def test_hash(self):
935        def mycmp(x, y):
936            return y - x
937        key = self.cmp_to_key(mycmp)
938        k = key(10)
939        self.assertRaises(TypeError, hash, k)
940        self.assertNotIsInstance(k, collections.abc.Hashable)
941
942
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests run against the C-accelerated functools.
    # The attribute is bound only when that module was importable; otherwise
    # the skip decorator above keeps these tests from running.
    if c_functools is not None:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
947
948
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests run against the pure Python functools
    # (py_functools is imported with the _functools accelerator blocked).
    # staticmethod keeps the plain function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
951
952
class TestTotalOrdering(unittest.TestCase):
    # Tests for the functools.total_ordering class decorator.

    def _check_derived_order(self, cls):
        # Every ordering operator on cls must agree with the order of the
        # integers passed to its constructor (1 < 2, and 2 == 2).
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_order(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_order(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_order(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_order(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # int already defines every comparison; the decorator should not
        # overwrite those existing methods.
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_order(A)

    def test_no_operations_defined(self):
        # A class with none of the four ordering methods is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042: when the user-defined comparison returns NotImplemented,
        # the derived operators must pass it on (so Python raises TypeError)
        # instead of recursing until the stack overflows.
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value
            def __lt__(self, other):
                return NotImplemented

        # Comparisons across unrelated types: every one must raise TypeError.
        mixed_comparisons = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)),
            ("GT > GE",
             lambda: ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)),
            ("GE >= LE",
             lambda: ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)),
        ]
        for label, compare in mixed_comparisons:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal objects whose base comparison is NotImplemented: the derived
        # >= and <= must still raise rather than fall back on equality.
        equal_but_unordered = [
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        ]
        for label, value, compare in equal_but_unordered:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Each comparison method of a decorated class must pickle by
        # reference and unpickle to the very same object.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    roundtrip = pickle.loads(pickle.dumps(original, proto))
                    self.assertIs(roundtrip, original)
1155
# Module-level fixture for TestTotalOrdering.test_pickle, which checks that
# the comparison methods of this class survive a pickle round trip by
# identity.  Pickle stores functions by reference (module plus qualified
# name), so Orderable_LT.__lt__ can only be found again if the class lives at
# module level.
@functools.total_ordering
class Orderable_LT:
    def __init__(self, value):
        self.value = value
    # Only __lt__ and __eq__ are written by hand; total_ordering supplies the rest.
    def __lt__(self, other):
        return self.value < other.value
    def __eq__(self, other):
        return self.value == other.value
1164
1165
1166class TestCache:
1167    # This tests that the pass-through is working as designed.
1168    # The underlying functionality is tested in TestLRU.
1169
1170    def test_cache(self):
1171        @self.module.cache
1172        def fib(n):
1173            if n < 2:
1174                return n
1175            return fib(n-1) + fib(n-2)
1176        self.assertEqual([fib(n) for n in range(16)],
1177            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1178        self.assertEqual(fib.cache_info(),
1179            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1180        fib.cache_clear()
1181        self.assertEqual(fib.cache_info(),
1182            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1183
1184
1185class TestLRU:
1186
    def test_lru(self):
        # Covers lru_cache statistics, cache_clear(), bypassing the cache
        # through __wrapped__, and the maxsize 0, 1 and 2 edge cases.
        # cache_info() unpacks as (hits, misses, maxsize, currsize).
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 1000 random calls over 5 * 5 = 25 distinct argument pairs, with
        # room for only 20 entries.  Results must match the uncached
        # function, hits should outnumber misses, and the cache must end up full.
        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertTrue(hits > misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        # x, y still hold the last pair drawn in the loop above; after the
        # clear this call has to be a miss.
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        # __wrapped__ is the undecorated function; calling it must leave
        # the statistics untouched.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        # f_cnt counts real executions of the wrapped function.  It is bound
        # just after each decorated def and reached through ``nonlocal``.
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        # With two slots, least-recently-used eviction gives exactly four
        # misses in the sequence below (marked with * under the loop header).
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)
1275
1276    def test_lru_no_args(self):
1277        @self.module.lru_cache
1278        def square(x):
1279            return x ** 2
1280
1281        self.assertEqual(list(map(square, [10, 20, 10])),
1282                         [100, 400, 100])
1283        self.assertEqual(square.cache_info().hits, 1)
1284        self.assertEqual(square.cache_info().misses, 2)
1285        self.assertEqual(square.cache_info().maxsize, 128)
1286        self.assertEqual(square.cache_info().currsize, 2)
1287
1288    def test_lru_bug_35780(self):
1289        # C version of the lru_cache was not checking to see if
1290        # the user function call has already modified the cache
1291        # (this arises in recursive calls and in multi-threading).
1292        # This cause the cache to have orphan links not referenced
1293        # by the cache dictionary.
1294
1295        once = True                 # Modified by f(x) below
1296
1297        @self.module.lru_cache(maxsize=10)
1298        def f(x):
1299            nonlocal once
1300            rv = f'.{x}.'
1301            if x == 20 and once:
1302                once = False
1303                rv = f(x)
1304            return rv
1305
1306        # Fill the cache
1307        for x in range(15):
1308            self.assertEqual(f(x), f'.{x}.')
1309        self.assertEqual(f.cache_info().currsize, 10)
1310
1311        # Make a recursive call and make sure the cache remains full
1312        self.assertEqual(f(20), '.20.')
1313        self.assertEqual(f.cache_info().currsize, 10)
1314
1315    def test_lru_bug_36650(self):
1316        # C version of lru_cache was treating a call with an empty **kwargs
1317        # dictionary as being distinct from a call with no keywords at all.
1318        # This did not result in an incorrect answer, but it did trigger
1319        # an unexpected cache miss.
1320
1321        @self.module.lru_cache()
1322        def f(x):
1323            pass
1324
1325        f(0)
1326        f(0, **{})
1327        self.assertEqual(f.cache_info().hits, 1)
1328
    def test_lru_hash_only_once(self):
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it calls __hash__ only once
        # per use as an argument to the cached function.

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        # __mul__ is stubbed to return 15 (5 * 3), so f(mock_int, 1) == 16.
        # __hash__ is a Mock, so every call to it is counted.
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # cache_info() is compared below against
        # (hits, misses, maxsize, currsize) tuples.

        # Add to cache:  One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        # (maxsize=1, so caching f(6, 2) evicts the mock_int entry.)
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1363
1364    def test_lru_reentrancy_with_len(self):
1365        # Test to make sure the LRU cache code isn't thrown-off by
1366        # caching the built-in len() function.  Since len() can be
1367        # cached, we shouldn't use it inside the lru code itself.
1368        old_len = builtins.len
1369        try:
1370            builtins.len = self.module.lru_cache(4)(len)
1371            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1372                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1373        finally:
1374            builtins.len = old_len
1375
1376    def test_lru_star_arg_handling(self):
1377        # Test regression that arose in ea064ff3c10f
1378        @functools.lru_cache()
1379        def f(*args):
1380            return args
1381
1382        self.assertEqual(f(1, 2), (1, 2))
1383        self.assertEqual(f((1, 2)), ((1, 2),))
1384
1385    def test_lru_type_error(self):
1386        # Regression test for issue #28653.
1387        # lru_cache was leaking when one of the arguments
1388        # wasn't cacheable.
1389
1390        @functools.lru_cache(maxsize=None)
1391        def infinite_cache(o):
1392            pass
1393
1394        @functools.lru_cache(maxsize=10)
1395        def limited_cache(o):
1396            pass
1397
1398        with self.assertRaises(TypeError):
1399            infinite_cache([])
1400
1401        with self.assertRaises(TypeError):
1402            limited_cache([])
1403
1404    def test_lru_with_maxsize_none(self):
1405        @self.module.lru_cache(maxsize=None)
1406        def fib(n):
1407            if n < 2:
1408                return n
1409            return fib(n-1) + fib(n-2)
1410        self.assertEqual([fib(n) for n in range(16)],
1411            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1412        self.assertEqual(fib.cache_info(),
1413            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1414        fib.cache_clear()
1415        self.assertEqual(fib.cache_info(),
1416            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1417
1418    def test_lru_with_maxsize_negative(self):
1419        @self.module.lru_cache(maxsize=-10)
1420        def eq(n):
1421            return n
1422        for i in (0, 1):
1423            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1424        self.assertEqual(eq.cache_info(),
1425            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1426
1427    def test_lru_with_exceptions(self):
1428        # Verify that user_function exceptions get passed through without
1429        # creating a hard-to-read chained exception.
1430        # http://bugs.python.org/issue13177
1431        for maxsize in (None, 128):
1432            @self.module.lru_cache(maxsize)
1433            def func(i):
1434                return 'abc'[i]
1435            self.assertEqual(func(0), 'a')
1436            with self.assertRaises(IndexError) as cm:
1437                func(15)
1438            self.assertIsNone(cm.exception.__context__)
1439            # Verify that the previous exception did not result in a cached entry
1440            with self.assertRaises(IndexError):
1441                func(15)
1442
1443    def test_lru_with_types(self):
1444        for maxsize in (None, 128):
1445            @self.module.lru_cache(maxsize=maxsize, typed=True)
1446            def square(x):
1447                return x * x
1448            self.assertEqual(square(3), 9)
1449            self.assertEqual(type(square(3)), type(9))
1450            self.assertEqual(square(3.0), 9.0)
1451            self.assertEqual(type(square(3.0)), type(9.0))
1452            self.assertEqual(square(x=3), 9)
1453            self.assertEqual(type(square(x=3)), type(9))
1454            self.assertEqual(square(x=3.0), 9.0)
1455            self.assertEqual(type(square(x=3.0)), type(9.0))
1456            self.assertEqual(square.cache_info().hits, 4)
1457            self.assertEqual(square.cache_info().misses, 4)
1458
1459    def test_lru_with_keyword_args(self):
1460        @self.module.lru_cache()
1461        def fib(n):
1462            if n < 2:
1463                return n
1464            return fib(n=n-1) + fib(n=n-2)
1465        self.assertEqual(
1466            [fib(n=number) for number in range(16)],
1467            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1468        )
1469        self.assertEqual(fib.cache_info(),
1470            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1471        fib.cache_clear()
1472        self.assertEqual(fib.cache_info(),
1473            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1474
1475    def test_lru_with_keyword_args_maxsize_none(self):
1476        @self.module.lru_cache(maxsize=None)
1477        def fib(n):
1478            if n < 2:
1479                return n
1480            return fib(n=n-1) + fib(n=n-2)
1481        self.assertEqual([fib(n=number) for number in range(16)],
1482            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1483        self.assertEqual(fib.cache_info(),
1484            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1485        fib.cache_clear()
1486        self.assertEqual(fib.cache_info(),
1487            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1488
1489    def test_kwargs_order(self):
1490        # PEP 468: Preserving Keyword Argument Order
1491        @self.module.lru_cache(maxsize=10)
1492        def f(**kwargs):
1493            return list(kwargs.items())
1494        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1495        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1496        self.assertEqual(f.cache_info(),
1497            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1498
1499    def test_lru_cache_decoration(self):
1500        def f(zomg: 'zomg_annotation'):
1501            """f doc string"""
1502            return 42
1503        g = self.module.lru_cache()(f)
1504        for attr in self.module.WRAPPER_ASSIGNMENTS:
1505            self.assertEqual(getattr(g, attr), getattr(f, attr))
1506
    def test_lru_cache_threaded(self):
        # Thread-safety check in two phases.  Phase 1: n threads each call
        # f(k, 0) m times with their own k, and the statistics are checked
        # afterwards.  Phase 2: the same workers run while one extra thread
        # clears the cache repeatedly.  That phase makes no assertions of its
        # own beyond the workers' result checks.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All worker threads block on this event (10 s timeout) so that
        # they are released together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # Use a tiny thread switch interval so the threads interleave as
        # often as possible; the original value is restored in ``finally``.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            # Ideally every thread misses once (its first call) and hits on
            # the remaining m - 1 calls.
            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure Python version these counts are only
                # checked as upper bounds; why they can differ is not known.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
1554
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # n worker threads and the main thread proceed in m lock-step rounds.
        # In round i every worker calls f(i), and f blocks on ``pause`` until
        # all workers and the main thread arrive, so all n calls are in
        # flight at once.  None of them can be a hit, so each round adds n
        # misses but only one cache entry.
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        # (hits, misses, maxsize, currsize)
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # The main thread is the extra party on each barrier.  Every
                # reset targets a barrier that no thread can be waiting on at
                # that point in the round.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1581
    def test_lru_cache_threaded3(self):
        # Five threads call a maxsize=2 cached function concurrently, with
        # overlapping arguments [1, 2, 2, 3, 2].  f sleeps for 10 ms, so the
        # calls are likely to overlap and may evict each other's entries
        # while running.  Each thread only checks that it gets the right
        # result.
        @self.module.lru_cache(maxsize=2)
        def f(x):
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        # start_threads is expected to start the threads and join them on exit.
        with support.start_threads(threads):
            pass
1594
1595    def test_need_for_rlock(self):
1596        # This will deadlock on an LRU cache that uses a regular lock
1597
1598        @self.module.lru_cache(maxsize=10)
1599        def test_func(x):
1600            'Used to demonstrate a reentrant lru_cache call within a single thread'
1601            return x
1602
1603        class DoubleEq:
1604            'Demonstrate a reentrant lru_cache call within a single thread'
1605            def __init__(self, x):
1606                self.x = x
1607            def __hash__(self):
1608                return self.x
1609            def __eq__(self, other):
1610                if self.x == 2:
1611                    test_func(DoubleEq(1))
1612                return self.x == other.x
1613
1614        test_func(DoubleEq(1))                      # Load the cache
1615        test_func(DoubleEq(2))                      # Load the cache
1616        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1617                         DoubleEq(2))               # Verify the correct return value
1618
1619    def test_lru_method(self):
1620        class X(int):
1621            f_cnt = 0
1622            @self.module.lru_cache(2)
1623            def f(self, x):
1624                self.f_cnt += 1
1625                return x*10+self
1626        a = X(5)
1627        b = X(5)
1628        c = X(7)
1629        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1630
1631        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1632            self.assertEqual(a.f(x), x*10 + 5)
1633        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1634        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1635
1636        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1637            self.assertEqual(b.f(x), x*10 + 5)
1638        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1639        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1640
1641        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1642            self.assertEqual(c.f(x), x*10 + 7)
1643        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1644        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1645
1646        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1647        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1648        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1649
1650    def test_pickle(self):
1651        cls = self.__class__
1652        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1653            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1654                with self.subTest(proto=proto, func=f):
1655                    f_copy = pickle.loads(pickle.dumps(f, proto))
1656                    self.assertIs(f_copy, f)
1657
1658    def test_copy(self):
1659        cls = self.__class__
1660        def orig(x, y):
1661            return 3 * x + y
1662        part = self.module.partial(orig, 2)
1663        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1664                 self.module.lru_cache(2)(part))
1665        for f in funcs:
1666            with self.subTest(func=f):
1667                f_copy = copy.copy(f)
1668                self.assertIs(f_copy, f)
1669
1670    def test_deepcopy(self):
1671        cls = self.__class__
1672        def orig(x, y):
1673            return 3 * x + y
1674        part = self.module.partial(orig, 2)
1675        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1676                 self.module.lru_cache(2)(part))
1677        for f in funcs:
1678            with self.subTest(func=f):
1679                f_copy = copy.deepcopy(f)
1680                self.assertIs(f_copy, f)
1681
1682    def test_lru_cache_parameters(self):
1683        @self.module.lru_cache(maxsize=2)
1684        def f():
1685            return 1
1686        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1687
1688        @self.module.lru_cache(maxsize=1000, typed=True)
1689        def f():
1690            return 1
1691        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1692
1693    def test_lru_cache_weakrefable(self):
1694        @self.module.lru_cache
1695        def test_function(x):
1696            return x
1697
1698        class A:
1699            @self.module.lru_cache
1700            def test_method(self, x):
1701                return (self, x)
1702
1703            @staticmethod
1704            @self.module.lru_cache
1705            def test_staticmethod(x):
1706                return (self, x)
1707
1708        refs = [weakref.ref(test_function),
1709                weakref.ref(A.test_method),
1710                weakref.ref(A.test_staticmethod)]
1711
1712        for ref in refs:
1713            self.assertIsNotNone(ref())
1714
1715        del A
1716        del test_function
1717        gc.collect()
1718
1719        for ref in refs:
1720            self.assertIsNone(ref())
1721
1722
1723@py_functools.lru_cache()
1724def py_cached_func(x, y):
1725    return 3 * x + y
1726
1727@c_functools.lru_cache()
1728def c_cached_func(x, y):
1729    return 3 * x + y
1730
1731
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the pure-Python functools.
    module = py_functools
    # Kept in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up through the class or an instance.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1744
1745
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the C-accelerated functools.
    module = c_functools
    # Kept in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up through the class or an instance.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1758
1759
1760class TestSingleDispatch(unittest.TestCase):
1761    def test_simple_overloads(self):
1762        @functools.singledispatch
1763        def g(obj):
1764            return "base"
1765        def g_int(i):
1766            return "integer"
1767        g.register(int, g_int)
1768        self.assertEqual(g("str"), "base")
1769        self.assertEqual(g(1), "integer")
1770        self.assertEqual(g([1,2,3]), "base")
1771
1772    def test_mro(self):
1773        @functools.singledispatch
1774        def g(obj):
1775            return "base"
1776        class A:
1777            pass
1778        class C(A):
1779            pass
1780        class B(A):
1781            pass
1782        class D(C, B):
1783            pass
1784        def g_A(a):
1785            return "A"
1786        def g_B(b):
1787            return "B"
1788        g.register(A, g_A)
1789        g.register(B, g_B)
1790        self.assertEqual(g(A()), "A")
1791        self.assertEqual(g(B()), "B")
1792        self.assertEqual(g(C()), "A")
1793        self.assertEqual(g(D()), "B")
1794
1795    def test_register_decorator(self):
1796        @functools.singledispatch
1797        def g(obj):
1798            return "base"
1799        @g.register(int)
1800        def g_int(i):
1801            return "int %s" % (i,)
1802        self.assertEqual(g(""), "base")
1803        self.assertEqual(g(12), "int 12")
1804        self.assertIs(g.dispatch(int), g_int)
1805        self.assertIs(g.dispatch(object), g.dispatch(str))
1806        # Note: in the assert above this is not g.
1807        # @singledispatch returns the wrapper.
1808
1809    def test_wrapping_attributes(self):
1810        @functools.singledispatch
1811        def g(obj):
1812            "Simple test"
1813            return "Test"
1814        self.assertEqual(g.__name__, "g")
1815        if sys.flags.optimize < 2:
1816            self.assertEqual(g.__doc__, "Simple test")
1817
1818    @unittest.skipUnless(decimal, 'requires _decimal')
1819    @support.cpython_only
1820    def test_c_classes(self):
1821        @functools.singledispatch
1822        def g(obj):
1823            return "base"
1824        @g.register(decimal.DecimalException)
1825        def _(obj):
1826            return obj.args
1827        subn = decimal.Subnormal("Exponent < Emin")
1828        rnd = decimal.Rounded("Number got rounded")
1829        self.assertEqual(g(subn), ("Exponent < Emin",))
1830        self.assertEqual(g(rnd), ("Number got rounded",))
1831        @g.register(decimal.Subnormal)
1832        def _(obj):
1833            return "Too small to care."
1834        self.assertEqual(g(subn), "Too small to care.")
1835        self.assertEqual(g(rnd), ("Number got rounded",))
1836
1837    def test_compose_mro(self):
1838        # None of the examples in this test depend on haystack ordering.
1839        c = collections.abc
1840        mro = functools._compose_mro
1841        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1842        for haystack in permutations(bases):
1843            m = mro(dict, haystack)
1844            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1845                                 c.Collection, c.Sized, c.Iterable,
1846                                 c.Container, object])
1847        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1848        for haystack in permutations(bases):
1849            m = mro(collections.ChainMap, haystack)
1850            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1851                                 c.Collection, c.Sized, c.Iterable,
1852                                 c.Container, object])
1853
1854        # If there's a generic function with implementations registered for
1855        # both Sized and Container, passing a defaultdict to it results in an
1856        # ambiguous dispatch which will cause a RuntimeError (see
1857        # test_mro_conflicts).
1858        bases = [c.Container, c.Sized, str]
1859        for haystack in permutations(bases):
1860            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1861            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1862                                 c.Container, object])
1863
1864        # MutableSequence below is registered directly on D. In other words, it
1865        # precedes MutableMapping which means single dispatch will always
1866        # choose MutableSequence here.
1867        class D(collections.defaultdict):
1868            pass
1869        c.MutableSequence.register(D)
1870        bases = [c.MutableSequence, c.MutableMapping]
1871        for haystack in permutations(bases):
1872            m = mro(D, bases)
1873            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1874                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
1875                                 c.Collection, c.Sized, c.Iterable, c.Container,
1876                                 object])
1877
1878        # Container and Callable are registered on different base classes and
1879        # a generic function supporting both should always pick the Callable
1880        # implementation if a C instance is passed.
1881        class C(collections.defaultdict):
1882            def __call__(self):
1883                pass
1884        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1885        for haystack in permutations(bases):
1886            m = mro(C, haystack)
1887            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
1888                                 c.Collection, c.Sized, c.Iterable,
1889                                 c.Container, object])
1890
1891    def test_register_abc(self):
1892        c = collections.abc
1893        d = {"a": "b"}
1894        l = [1, 2, 3]
1895        s = {object(), None}
1896        f = frozenset(s)
1897        t = (1, 2, 3)
1898        @functools.singledispatch
1899        def g(obj):
1900            return "base"
1901        self.assertEqual(g(d), "base")
1902        self.assertEqual(g(l), "base")
1903        self.assertEqual(g(s), "base")
1904        self.assertEqual(g(f), "base")
1905        self.assertEqual(g(t), "base")
1906        g.register(c.Sized, lambda obj: "sized")
1907        self.assertEqual(g(d), "sized")
1908        self.assertEqual(g(l), "sized")
1909        self.assertEqual(g(s), "sized")
1910        self.assertEqual(g(f), "sized")
1911        self.assertEqual(g(t), "sized")
1912        g.register(c.MutableMapping, lambda obj: "mutablemapping")
1913        self.assertEqual(g(d), "mutablemapping")
1914        self.assertEqual(g(l), "sized")
1915        self.assertEqual(g(s), "sized")
1916        self.assertEqual(g(f), "sized")
1917        self.assertEqual(g(t), "sized")
1918        g.register(collections.ChainMap, lambda obj: "chainmap")
1919        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
1920        self.assertEqual(g(l), "sized")
1921        self.assertEqual(g(s), "sized")
1922        self.assertEqual(g(f), "sized")
1923        self.assertEqual(g(t), "sized")
1924        g.register(c.MutableSequence, lambda obj: "mutablesequence")
1925        self.assertEqual(g(d), "mutablemapping")
1926        self.assertEqual(g(l), "mutablesequence")
1927        self.assertEqual(g(s), "sized")
1928        self.assertEqual(g(f), "sized")
1929        self.assertEqual(g(t), "sized")
1930        g.register(c.MutableSet, lambda obj: "mutableset")
1931        self.assertEqual(g(d), "mutablemapping")
1932        self.assertEqual(g(l), "mutablesequence")
1933        self.assertEqual(g(s), "mutableset")
1934        self.assertEqual(g(f), "sized")
1935        self.assertEqual(g(t), "sized")
1936        g.register(c.Mapping, lambda obj: "mapping")
1937        self.assertEqual(g(d), "mutablemapping")  # not specific enough
1938        self.assertEqual(g(l), "mutablesequence")
1939        self.assertEqual(g(s), "mutableset")
1940        self.assertEqual(g(f), "sized")
1941        self.assertEqual(g(t), "sized")
1942        g.register(c.Sequence, lambda obj: "sequence")
1943        self.assertEqual(g(d), "mutablemapping")
1944        self.assertEqual(g(l), "mutablesequence")
1945        self.assertEqual(g(s), "mutableset")
1946        self.assertEqual(g(f), "sized")
1947        self.assertEqual(g(t), "sequence")
1948        g.register(c.Set, lambda obj: "set")
1949        self.assertEqual(g(d), "mutablemapping")
1950        self.assertEqual(g(l), "mutablesequence")
1951        self.assertEqual(g(s), "mutableset")
1952        self.assertEqual(g(f), "set")
1953        self.assertEqual(g(t), "sequence")
1954        g.register(dict, lambda obj: "dict")
1955        self.assertEqual(g(d), "dict")
1956        self.assertEqual(g(l), "mutablesequence")
1957        self.assertEqual(g(s), "mutableset")
1958        self.assertEqual(g(f), "set")
1959        self.assertEqual(g(t), "sequence")
1960        g.register(list, lambda obj: "list")
1961        self.assertEqual(g(d), "dict")
1962        self.assertEqual(g(l), "list")
1963        self.assertEqual(g(s), "mutableset")
1964        self.assertEqual(g(f), "set")
1965        self.assertEqual(g(t), "sequence")
1966        g.register(set, lambda obj: "concrete-set")
1967        self.assertEqual(g(d), "dict")
1968        self.assertEqual(g(l), "list")
1969        self.assertEqual(g(s), "concrete-set")
1970        self.assertEqual(g(f), "set")
1971        self.assertEqual(g(t), "sequence")
1972        g.register(frozenset, lambda obj: "frozen-set")
1973        self.assertEqual(g(d), "dict")
1974        self.assertEqual(g(l), "list")
1975        self.assertEqual(g(s), "concrete-set")
1976        self.assertEqual(g(f), "frozen-set")
1977        self.assertEqual(g(t), "sequence")
1978        g.register(tuple, lambda obj: "tuple")
1979        self.assertEqual(g(d), "dict")
1980        self.assertEqual(g(l), "list")
1981        self.assertEqual(g(s), "concrete-set")
1982        self.assertEqual(g(f), "frozen-set")
1983        self.assertEqual(g(t), "tuple")
1984
1985    def test_c3_abc(self):
1986        c = collections.abc
1987        mro = functools._c3_mro
1988        class A(object):
1989            pass
1990        class B(A):
1991            def __len__(self):
1992                return 0   # implies Sized
1993        @c.Container.register
1994        class C(object):
1995            pass
1996        class D(object):
1997            pass   # unrelated
1998        class X(D, C, B):
1999            def __call__(self):
2000                pass   # implies Callable
2001        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2002        for abcs in permutations([c.Sized, c.Callable, c.Container]):
2003            self.assertEqual(mro(X, abcs=abcs), expected)
2004        # unrelated ABCs don't appear in the resulting MRO
2005        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2006        self.assertEqual(mro(X, abcs=many_abcs), expected)
2007
2008    def test_false_meta(self):
2009        # see issue23572
2010        class MetaA(type):
2011            def __len__(self):
2012                return 0
2013        class A(metaclass=MetaA):
2014            pass
2015        class AA(A):
2016            pass
2017        @functools.singledispatch
2018        def fun(a):
2019            return 'base A'
2020        @fun.register(A)
2021        def _(a):
2022            return 'fun A'
2023        aa = AA()
2024        self.assertEqual(fun(aa), 'fun A')
2025
    def test_mro_conflicts(self):
        # Checks how singledispatch resolves, or refuses to resolve,
        # implementations for classes related to several registered ABCs.
        # Some relations come from explicit bases, others only from ABC
        # registration or from inferred special methods (__len__).
        # The order of the registrations below matters.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no explicit ABC bases: every relation comes from registration.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # Iterable and Container are now both registered on P and neither is
        # more specific, so dispatch raises instead of guessing.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two candidates may be reported in either order.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        # h only knows Sized and Container; reused by the T and U cases below.
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on R takes precedence over the one
        # inherited through defaultdict.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # A registered ABC can overtake a concrete base (S) when it lands
        # earlier in the composed MRO.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2153
    def test_cache_invalidation(self):
        # Observes singledispatch's dispatch cache directly. It checks which
        # calls read it, which write it, and which events empty it
        # (registration, ABC registration once ABCs are involved, and
        # _clear_cache()).
        from collections import UserDict
        import weakref

        # Dict that records every key read (get_ops) and written (set_ops).
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            # Empties the data without recording anything.
            def clear(self):
                self.data.clear()

        td = TracingDict()
        # While g is created, weakref.WeakKeyDictionary() yields td. The
        # assertions below rely on g's dispatch cache being td.
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            # First call for a type: a cache write, no read.
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeat calls are served from the cache: reads only.
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering an implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # From here on g has an ABC-based implementation.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() goes through the same cache as calling g.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            # The stale entries are only dropped on the next call.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # Explicit invalidation.
            g._clear_cache()
            self.assertEqual(len(td), 0)
2255
2256    def test_annotations(self):
2257        @functools.singledispatch
2258        def i(arg):
2259            return "base"
2260        @i.register
2261        def _(arg: collections.abc.Mapping):
2262            return "mapping"
2263        @i.register
2264        def _(arg: "collections.abc.Sequence"):
2265            return "sequence"
2266        self.assertEqual(i(None), "base")
2267        self.assertEqual(i({"a": 1}), "mapping")
2268        self.assertEqual(i([1, 2, 3]), "sequence")
2269        self.assertEqual(i((1, 2, 3)), "sequence")
2270        self.assertEqual(i("str"), "sequence")
2271
2272        # Registering classes as callables doesn't work with annotations,
2273        # you need to pass the type explicitly.
2274        @i.register(str)
2275        class _:
2276            def __init__(self, arg):
2277                self.arg = arg
2278
2279            def __eq__(self, other):
2280                return self.arg == other
2281        self.assertEqual(i("str"), "str")
2282
2283    def test_method_register(self):
2284        class A:
2285            @functools.singledispatchmethod
2286            def t(self, arg):
2287                self.arg = "base"
2288            @t.register(int)
2289            def _(self, arg):
2290                self.arg = "int"
2291            @t.register(str)
2292            def _(self, arg):
2293                self.arg = "str"
2294        a = A()
2295
2296        a.t(0)
2297        self.assertEqual(a.arg, "int")
2298        aa = A()
2299        self.assertFalse(hasattr(aa, 'arg'))
2300        a.t('')
2301        self.assertEqual(a.arg, "str")
2302        aa = A()
2303        self.assertFalse(hasattr(aa, 'arg'))
2304        a.t(0.0)
2305        self.assertEqual(a.arg, "base")
2306        aa = A()
2307        self.assertFalse(hasattr(aa, 'arg'))
2308
2309    def test_staticmethod_register(self):
2310        class A:
2311            @functools.singledispatchmethod
2312            @staticmethod
2313            def t(arg):
2314                return arg
2315            @t.register(int)
2316            @staticmethod
2317            def _(arg):
2318                return isinstance(arg, int)
2319            @t.register(str)
2320            @staticmethod
2321            def _(arg):
2322                return isinstance(arg, str)
2323        a = A()
2324
2325        self.assertTrue(A.t(0))
2326        self.assertTrue(A.t(''))
2327        self.assertEqual(A.t(0.0), 0.0)
2328
2329    def test_classmethod_register(self):
2330        class A:
2331            def __init__(self, arg):
2332                self.arg = arg
2333
2334            @functools.singledispatchmethod
2335            @classmethod
2336            def t(cls, arg):
2337                return cls("base")
2338            @t.register(int)
2339            @classmethod
2340            def _(cls, arg):
2341                return cls("int")
2342            @t.register(str)
2343            @classmethod
2344            def _(cls, arg):
2345                return cls("str")
2346
2347        self.assertEqual(A.t(0).arg, "int")
2348        self.assertEqual(A.t('').arg, "str")
2349        self.assertEqual(A.t(0.0).arg, "base")
2350
2351    def test_callable_register(self):
2352        class A:
2353            def __init__(self, arg):
2354                self.arg = arg
2355
2356            @functools.singledispatchmethod
2357            @classmethod
2358            def t(cls, arg):
2359                return cls("base")
2360
2361        @A.t.register(int)
2362        @classmethod
2363        def _(cls, arg):
2364            return cls("int")
2365        @A.t.register(str)
2366        @classmethod
2367        def _(cls, arg):
2368            return cls("str")
2369
2370        self.assertEqual(A.t(0).arg, "int")
2371        self.assertEqual(A.t('').arg, "str")
2372        self.assertEqual(A.t(0.0).arg, "base")
2373
2374    def test_abstractmethod_register(self):
2375        class Abstract(abc.ABCMeta):
2376
2377            @functools.singledispatchmethod
2378            @abc.abstractmethod
2379            def add(self, x, y):
2380                pass
2381
2382        self.assertTrue(Abstract.add.__isabstractmethod__)
2383
2384    def test_type_ann_register(self):
2385        class A:
2386            @functools.singledispatchmethod
2387            def t(self, arg):
2388                return "base"
2389            @t.register
2390            def _(self, arg: int):
2391                return "int"
2392            @t.register
2393            def _(self, arg: str):
2394                return "str"
2395        a = A()
2396
2397        self.assertEqual(a.t(0), "int")
2398        self.assertEqual(a.t(''), "str")
2399        self.assertEqual(a.t(0.0), "base")
2400
2401    def test_invalid_registrations(self):
2402        msg_prefix = "Invalid first argument to `register()`: "
2403        msg_suffix = (
2404            ". Use either `@register(some_class)` or plain `@register` on an "
2405            "annotated function."
2406        )
2407        @functools.singledispatch
2408        def i(arg):
2409            return "base"
2410        with self.assertRaises(TypeError) as exc:
2411            @i.register(42)
2412            def _(arg):
2413                return "I annotated with a non-type"
2414        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2415        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2416        with self.assertRaises(TypeError) as exc:
2417            @i.register
2418            def _(arg):
2419                return "I forgot to annotate"
2420        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2421            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2422        ))
2423        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2424
2425        with self.assertRaises(TypeError) as exc:
2426            @i.register
2427            def _(arg: typing.Iterable[str]):
2428                # At runtime, dispatching on generics is impossible.
2429                # When registering implementations with singledispatch, avoid
2430                # types from `typing`. Instead, annotate with regular types
2431                # or ABCs.
2432                return "I annotated with a generic collection"
2433        self.assertTrue(str(exc.exception).startswith(
2434            "Invalid annotation for 'arg'."
2435        ))
2436        self.assertTrue(str(exc.exception).endswith(
2437            'typing.Iterable[str] is not a class.'
2438        ))
2439
2440    def test_invalid_positional_argument(self):
2441        @functools.singledispatch
2442        def f(*args):
2443            pass
2444        msg = 'f requires at least 1 positional argument'
2445        with self.assertRaisesRegex(TypeError, msg):
2446            f()
2447
2448
class CachedCostItem:
    """Fixture whose ``cost`` getter bumps ``_cost`` by one on every run.

    ``cost`` is wrapped in the pure-Python ``cached_property``, so repeated
    reads should observe the value produced by the first run only.
    """

    # Class-level starting value; the first ``+=`` in ``cost`` creates the
    # instance attribute.
    _cost = 1

    def __init__(self):
        # Use the public threading API.  ``RLock`` is not a documented part
        # of the functools module, even though it happens to be reachable there.
        self.lock = threading.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost += 1
        return self._cost
2461
2462
class OptionallyCachedCostItem:
    """Fixture exposing one getter both uncached and cached.

    ``get_cost`` recomputes on every call; ``cached_cost`` wraps the very same
    function in ``cached_property`` under a different attribute name.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2472
2473
class CachedCostItemWait:
    """Fixture whose cached ``cost`` blocks on an event before computing.

    Several threads can therefore be parked inside the getter at the same
    time, which lets a test check that the value is still computed only once.
    """

    def __init__(self, event):
        self._cost = 1
        # Public threading API rather than the undocumented functools name.
        self.lock = threading.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Bounded wait so a test cannot hang if the event is never set.
        self.event.wait(1)
        with self.lock:
            self._cost += 1
        return self._cost
2487
2488
class CachedCostItemWithSlots:
    """Fixture without a per-instance ``__dict__``.

    ``cached_property`` stores its result in the instance ``__dict__``, so
    reading ``cost`` on this class is expected to fail before the getter runs.
    """

    # A real one-element tuple.  ``('_cost')`` is just the string '_cost';
    # Python accepts that too, but the parentheses wrongly suggest a tuple.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2498
2499
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python ``functools.cached_property``."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        # ``cached_cost`` caches independently of the plain ``get_cost`` method,
        # even though both wrap the same function.
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Shrink the switch interval so the threads interleave as much as
        # possible while racing to compute ``item.cost``.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with support.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # Each run of the getter adds one to ``_cost`` (starting at 1), so 2
        # means it ran exactly once despite the concurrent access.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
                TypeError,
                "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # Here the "instance" is a class, whose ``__dict__`` rejects item
        # assignment, so the value cannot be cached there.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The RuntimeError raised during class creation wraps the TypeError
        # from ``__set_name__``.  Check both its type and its message.
        context = ctx.exception.__context__
        self.assertIsInstance(context, TypeError)
        self.assertEqual(
            str(context),
            "Cannot assign the same cached_property to two different names ('a' and 'b')."
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Assigning the descriptor after class creation skips __set_name__.
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        Foo.cp = cp

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2612
2613
# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()
2616