import abc
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import threading
import time
import typing
import unittest
import unittest.mock
import os
import weakref
import gc
from weakref import proxy
import contextlib

from test.support.script_helper import assert_python_ok

import functools

# Two independent copies of functools: one forced to the pure-Python
# implementation (C accelerator blocked) and one using the C accelerator.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

decimal = support.import_fresh_module('decimal', fresh=['_decimal'])

@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The original module object is restored when the with-block exits,
    even if it exits with an exception.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        sys.modules[name] = original_module

def capture(*args, **kw):
    """capture all positional and keyword arguments"""
    return args, kw


def signature(part):
    """ return the signature of a partial object """
    return (part.func, part.args, part.keywords, part.__dict__)

# tuple/dict subclasses used to check that partial.__setstate__ normalises
# its state to exact tuple/dict types.
class MyTuple(tuple):
    pass

# A tuple subclass whose __add__ returns a list; partial must not use it
# when concatenating stored and call-time positional arguments.
class BadTuple(tuple):
    def __add__(self, other):
        return list(self) + list(other)

class MyDict(dict):
    pass


class TestPartial:
    """Implementation-independent tests for functools.partial.

    Mixin: concrete subclasses set ``partial`` (the class under test) and
    ``AllowPickle`` (a context manager that makes pickling resolve to that
    implementation).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The first argument must be callable.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # assertEqual (rather than assertTrue(a == b and ...))
                # reports the mismatching values on failure.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Don't rely on reference counting to reclaim the referent.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not asserted, so accept both.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself, then breaks the
        # cycle in `finally` so the object can be collected.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None namespace is stored as an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this assertion is disabled in the original; confirm
        # whether None keywords become {} for every implementation before
        # enabling it.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; it must not leak into the call.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence (not a tuple) whose items are produced fresh on each
        # access; __setstate__ must reject it with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """TestPartial run against the C accelerated partial."""
    if c_functools:
        partial = c_functools.partial

    # The C partial pickles by reference to functools.partial already, so
    # no module swapping is needed.
    class AllowPickle:
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)


class TestPartialPy(TestPartial, unittest.TestCase):
    """TestPartial run against the pure-Python partial."""
    partial = py_functools.partial

    # Pickling looks the class up as functools.partial, so temporarily make
    # "functools" resolve to the pure-Python module.
    class AllowPickle:
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, type, value, tb):
            return self._cm.__exit__(type, value, tb)

if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass

class PyPartialSubclass(py_functools.partial):
    pass

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None

class TestPartialPySubclass(TestPartialPy):
    partial = PyPartialSubclass

class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract subclasses ABCMeta rather than using it as
        # a metaclass; the assertions only inspect function attributes, so
        # either form exercises __isabstractmethod__ propagation.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))


class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # Returns (wrapper, f) after update_wrapper with default arguments.
        # f carries a bogus __wrapped__ to check that it gets overwritten.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})


class TestWraps(TestUpdateWrapper):
    """Re-runs the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)


class TestReduce:
    """Implementation-independent tests for reduce.

    Mixin: concrete subclasses set ``reduce`` to the function under test.
    """

    def test_reduce(self):
        # Old-style sequence (no __iter__) that lazily computes squares and
        # raises IndexError past `max`.
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.reduce(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.reduce(add, Squares(10)), 285)
        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.reduce)
        self.assertRaises(TypeError, self.reduce, 42, 42)
        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.reduce, add, "")
        self.assertRaises(TypeError, self.reduce, add, ())
        self.assertRaises(TypeError, self.reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())

        self.assertEqual(self.reduce(add, [], None), None)
        self.assertEqual(self.reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.reduce(add, d), "".join(d.keys()))


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    if c_functools:
        # Builtins don't bind as methods, so no staticmethod is needed.
        reduce = c_functools.reduce


class TestReducePy(TestReduce, unittest.TestCase):
    # A plain Python function would bind to the instance; prevent that.
    reduce = staticmethod(py_functools.reduce)


class TestCmpToKey:
    """Implementation-independent tests for cmp_to_key.

    Mixin: concrete subclasses set ``cmp_to_key`` to the function under test.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # result against 0 also propagates.  The key must be rebuilt from
        # the new cmp1 (otherwise the first cmp1 is exercised again), and
        # `<` is used so that BadCmp.__lt__ is actually invoked: for `>`
        # Python would fall back to int.__lt__ and raise TypeError instead.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp1(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key


class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # A plain Python function would bind to the instance; prevent that.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)


class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value
<= other.value 1081 return NotImplemented 1082 1083 @functools.total_ordering 1084 class ImplementsGreaterThanEqualTo: 1085 def __init__(self, value): 1086 self.value = value 1087 def __eq__(self, other): 1088 if isinstance(other, ImplementsGreaterThanEqualTo): 1089 return self.value == other.value 1090 return False 1091 def __ge__(self, other): 1092 if isinstance(other, ImplementsGreaterThanEqualTo): 1093 return self.value >= other.value 1094 return NotImplemented 1095 1096 @functools.total_ordering 1097 class ComparatorNotImplemented: 1098 def __init__(self, value): 1099 self.value = value 1100 def __eq__(self, other): 1101 if isinstance(other, ComparatorNotImplemented): 1102 return self.value == other.value 1103 return False 1104 def __lt__(self, other): 1105 return NotImplemented 1106 1107 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1108 ImplementsLessThan(-1) < 1 1109 1110 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1111 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1112 1113 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1114 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1115 1116 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1117 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1118 1119 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1120 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1121 1122 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1123 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1124 1125 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1126 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1127 1128 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1129 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1130 1131 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1132 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1133 1134 with self.subTest("GE when equal"): 
1135 a = ComparatorNotImplemented(8) 1136 b = ComparatorNotImplemented(8) 1137 self.assertEqual(a, b) 1138 with self.assertRaises(TypeError): 1139 a >= b 1140 1141 with self.subTest("LE when equal"): 1142 a = ComparatorNotImplemented(9) 1143 b = ComparatorNotImplemented(9) 1144 self.assertEqual(a, b) 1145 with self.assertRaises(TypeError): 1146 a <= b 1147 1148 def test_pickle(self): 1149 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1150 for name in '__lt__', '__gt__', '__le__', '__ge__': 1151 with self.subTest(method=name, proto=proto): 1152 method = getattr(Orderable_LT, name) 1153 method_copy = pickle.loads(pickle.dumps(method, proto)) 1154 self.assertIs(method_copy, method) 1155 1156@functools.total_ordering 1157class Orderable_LT: 1158 def __init__(self, value): 1159 self.value = value 1160 def __lt__(self, other): 1161 return self.value < other.value 1162 def __eq__(self, other): 1163 return self.value == other.value 1164 1165 1166class TestCache: 1167 # This tests that the pass-through is working as designed. 1168 # The underlying functionality is tested in TestLRU. 
1169 1170 def test_cache(self): 1171 @self.module.cache 1172 def fib(n): 1173 if n < 2: 1174 return n 1175 return fib(n-1) + fib(n-2) 1176 self.assertEqual([fib(n) for n in range(16)], 1177 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1178 self.assertEqual(fib.cache_info(), 1179 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1180 fib.cache_clear() 1181 self.assertEqual(fib.cache_info(), 1182 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1183 1184 1185class TestLRU: 1186 1187 def test_lru(self): 1188 def orig(x, y): 1189 return 3 * x + y 1190 f = self.module.lru_cache(maxsize=20)(orig) 1191 hits, misses, maxsize, currsize = f.cache_info() 1192 self.assertEqual(maxsize, 20) 1193 self.assertEqual(currsize, 0) 1194 self.assertEqual(hits, 0) 1195 self.assertEqual(misses, 0) 1196 1197 domain = range(5) 1198 for i in range(1000): 1199 x, y = choice(domain), choice(domain) 1200 actual = f(x, y) 1201 expected = orig(x, y) 1202 self.assertEqual(actual, expected) 1203 hits, misses, maxsize, currsize = f.cache_info() 1204 self.assertTrue(hits > misses) 1205 self.assertEqual(hits + misses, 1000) 1206 self.assertEqual(currsize, 20) 1207 1208 f.cache_clear() # test clearing 1209 hits, misses, maxsize, currsize = f.cache_info() 1210 self.assertEqual(hits, 0) 1211 self.assertEqual(misses, 0) 1212 self.assertEqual(currsize, 0) 1213 f(x, y) 1214 hits, misses, maxsize, currsize = f.cache_info() 1215 self.assertEqual(hits, 0) 1216 self.assertEqual(misses, 1) 1217 self.assertEqual(currsize, 1) 1218 1219 # Test bypassing the cache 1220 self.assertIs(f.__wrapped__, orig) 1221 f.__wrapped__(x, y) 1222 hits, misses, maxsize, currsize = f.cache_info() 1223 self.assertEqual(hits, 0) 1224 self.assertEqual(misses, 1) 1225 self.assertEqual(currsize, 1) 1226 1227 # test size zero (which means "never-cache") 1228 @self.module.lru_cache(0) 1229 def f(): 1230 nonlocal f_cnt 1231 f_cnt += 1 1232 return 20 1233 
self.assertEqual(f.cache_info().maxsize, 0) 1234 f_cnt = 0 1235 for i in range(5): 1236 self.assertEqual(f(), 20) 1237 self.assertEqual(f_cnt, 5) 1238 hits, misses, maxsize, currsize = f.cache_info() 1239 self.assertEqual(hits, 0) 1240 self.assertEqual(misses, 5) 1241 self.assertEqual(currsize, 0) 1242 1243 # test size one 1244 @self.module.lru_cache(1) 1245 def f(): 1246 nonlocal f_cnt 1247 f_cnt += 1 1248 return 20 1249 self.assertEqual(f.cache_info().maxsize, 1) 1250 f_cnt = 0 1251 for i in range(5): 1252 self.assertEqual(f(), 20) 1253 self.assertEqual(f_cnt, 1) 1254 hits, misses, maxsize, currsize = f.cache_info() 1255 self.assertEqual(hits, 4) 1256 self.assertEqual(misses, 1) 1257 self.assertEqual(currsize, 1) 1258 1259 # test size two 1260 @self.module.lru_cache(2) 1261 def f(x): 1262 nonlocal f_cnt 1263 f_cnt += 1 1264 return x*10 1265 self.assertEqual(f.cache_info().maxsize, 2) 1266 f_cnt = 0 1267 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1268 # * * * * 1269 self.assertEqual(f(x), x*10) 1270 self.assertEqual(f_cnt, 4) 1271 hits, misses, maxsize, currsize = f.cache_info() 1272 self.assertEqual(hits, 12) 1273 self.assertEqual(misses, 4) 1274 self.assertEqual(currsize, 2) 1275 1276 def test_lru_no_args(self): 1277 @self.module.lru_cache 1278 def square(x): 1279 return x ** 2 1280 1281 self.assertEqual(list(map(square, [10, 20, 10])), 1282 [100, 400, 100]) 1283 self.assertEqual(square.cache_info().hits, 1) 1284 self.assertEqual(square.cache_info().misses, 2) 1285 self.assertEqual(square.cache_info().maxsize, 128) 1286 self.assertEqual(square.cache_info().currsize, 2) 1287 1288 def test_lru_bug_35780(self): 1289 # C version of the lru_cache was not checking to see if 1290 # the user function call has already modified the cache 1291 # (this arises in recursive calls and in multi-threading). 1292 # This cause the cache to have orphan links not referenced 1293 # by the cache dictionary. 
1294 1295 once = True # Modified by f(x) below 1296 1297 @self.module.lru_cache(maxsize=10) 1298 def f(x): 1299 nonlocal once 1300 rv = f'.{x}.' 1301 if x == 20 and once: 1302 once = False 1303 rv = f(x) 1304 return rv 1305 1306 # Fill the cache 1307 for x in range(15): 1308 self.assertEqual(f(x), f'.{x}.') 1309 self.assertEqual(f.cache_info().currsize, 10) 1310 1311 # Make a recursive call and make sure the cache remains full 1312 self.assertEqual(f(20), '.20.') 1313 self.assertEqual(f.cache_info().currsize, 10) 1314 1315 def test_lru_bug_36650(self): 1316 # C version of lru_cache was treating a call with an empty **kwargs 1317 # dictionary as being distinct from a call with no keywords at all. 1318 # This did not result in an incorrect answer, but it did trigger 1319 # an unexpected cache miss. 1320 1321 @self.module.lru_cache() 1322 def f(x): 1323 pass 1324 1325 f(0) 1326 f(0, **{}) 1327 self.assertEqual(f.cache_info().hits, 1) 1328 1329 def test_lru_hash_only_once(self): 1330 # To protect against weird reentrancy bugs and to improve 1331 # efficiency when faced with slow __hash__ methods, the 1332 # LRU cache guarantees that it will only call __hash__ 1333 # only once per use as an argument to the cached function. 
1334 1335 @self.module.lru_cache(maxsize=1) 1336 def f(x, y): 1337 return x * 3 + y 1338 1339 # Simulate the integer 5 1340 mock_int = unittest.mock.Mock() 1341 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1342 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1343 1344 # Add to cache: One use as an argument gives one call 1345 self.assertEqual(f(mock_int, 1), 16) 1346 self.assertEqual(mock_int.__hash__.call_count, 1) 1347 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1348 1349 # Cache hit: One use as an argument gives one additional call 1350 self.assertEqual(f(mock_int, 1), 16) 1351 self.assertEqual(mock_int.__hash__.call_count, 2) 1352 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1353 1354 # Cache eviction: No use as an argument gives no additional call 1355 self.assertEqual(f(6, 2), 20) 1356 self.assertEqual(mock_int.__hash__.call_count, 2) 1357 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1358 1359 # Cache miss: One use as an argument gives one additional call 1360 self.assertEqual(f(mock_int, 1), 16) 1361 self.assertEqual(mock_int.__hash__.call_count, 3) 1362 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1363 1364 def test_lru_reentrancy_with_len(self): 1365 # Test to make sure the LRU cache code isn't thrown-off by 1366 # caching the built-in len() function. Since len() can be 1367 # cached, we shouldn't use it inside the lru code itself. 1368 old_len = builtins.len 1369 try: 1370 builtins.len = self.module.lru_cache(4)(len) 1371 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1372 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1373 finally: 1374 builtins.len = old_len 1375 1376 def test_lru_star_arg_handling(self): 1377 # Test regression that arose in ea064ff3c10f 1378 @functools.lru_cache() 1379 def f(*args): 1380 return args 1381 1382 self.assertEqual(f(1, 2), (1, 2)) 1383 self.assertEqual(f((1, 2)), ((1, 2),)) 1384 1385 def test_lru_type_error(self): 1386 # Regression test for issue #28653. 
1387 # lru_cache was leaking when one of the arguments 1388 # wasn't cacheable. 1389 1390 @functools.lru_cache(maxsize=None) 1391 def infinite_cache(o): 1392 pass 1393 1394 @functools.lru_cache(maxsize=10) 1395 def limited_cache(o): 1396 pass 1397 1398 with self.assertRaises(TypeError): 1399 infinite_cache([]) 1400 1401 with self.assertRaises(TypeError): 1402 limited_cache([]) 1403 1404 def test_lru_with_maxsize_none(self): 1405 @self.module.lru_cache(maxsize=None) 1406 def fib(n): 1407 if n < 2: 1408 return n 1409 return fib(n-1) + fib(n-2) 1410 self.assertEqual([fib(n) for n in range(16)], 1411 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1412 self.assertEqual(fib.cache_info(), 1413 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1414 fib.cache_clear() 1415 self.assertEqual(fib.cache_info(), 1416 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1417 1418 def test_lru_with_maxsize_negative(self): 1419 @self.module.lru_cache(maxsize=-10) 1420 def eq(n): 1421 return n 1422 for i in (0, 1): 1423 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1424 self.assertEqual(eq.cache_info(), 1425 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1426 1427 def test_lru_with_exceptions(self): 1428 # Verify that user_function exceptions get passed through without 1429 # creating a hard-to-read chained exception. 
1430 # http://bugs.python.org/issue13177 1431 for maxsize in (None, 128): 1432 @self.module.lru_cache(maxsize) 1433 def func(i): 1434 return 'abc'[i] 1435 self.assertEqual(func(0), 'a') 1436 with self.assertRaises(IndexError) as cm: 1437 func(15) 1438 self.assertIsNone(cm.exception.__context__) 1439 # Verify that the previous exception did not result in a cached entry 1440 with self.assertRaises(IndexError): 1441 func(15) 1442 1443 def test_lru_with_types(self): 1444 for maxsize in (None, 128): 1445 @self.module.lru_cache(maxsize=maxsize, typed=True) 1446 def square(x): 1447 return x * x 1448 self.assertEqual(square(3), 9) 1449 self.assertEqual(type(square(3)), type(9)) 1450 self.assertEqual(square(3.0), 9.0) 1451 self.assertEqual(type(square(3.0)), type(9.0)) 1452 self.assertEqual(square(x=3), 9) 1453 self.assertEqual(type(square(x=3)), type(9)) 1454 self.assertEqual(square(x=3.0), 9.0) 1455 self.assertEqual(type(square(x=3.0)), type(9.0)) 1456 self.assertEqual(square.cache_info().hits, 4) 1457 self.assertEqual(square.cache_info().misses, 4) 1458 1459 def test_lru_with_keyword_args(self): 1460 @self.module.lru_cache() 1461 def fib(n): 1462 if n < 2: 1463 return n 1464 return fib(n=n-1) + fib(n=n-2) 1465 self.assertEqual( 1466 [fib(n=number) for number in range(16)], 1467 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1468 ) 1469 self.assertEqual(fib.cache_info(), 1470 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1471 fib.cache_clear() 1472 self.assertEqual(fib.cache_info(), 1473 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1474 1475 def test_lru_with_keyword_args_maxsize_none(self): 1476 @self.module.lru_cache(maxsize=None) 1477 def fib(n): 1478 if n < 2: 1479 return n 1480 return fib(n=n-1) + fib(n=n-2) 1481 self.assertEqual([fib(n=number) for number in range(16)], 1482 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1483 self.assertEqual(fib.cache_info(), 1484 
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1485 fib.cache_clear() 1486 self.assertEqual(fib.cache_info(), 1487 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1488 1489 def test_kwargs_order(self): 1490 # PEP 468: Preserving Keyword Argument Order 1491 @self.module.lru_cache(maxsize=10) 1492 def f(**kwargs): 1493 return list(kwargs.items()) 1494 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1495 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1496 self.assertEqual(f.cache_info(), 1497 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1498 1499 def test_lru_cache_decoration(self): 1500 def f(zomg: 'zomg_annotation'): 1501 """f doc string""" 1502 return 42 1503 g = self.module.lru_cache()(f) 1504 for attr in self.module.WRAPPER_ASSIGNMENTS: 1505 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1506 1507 def test_lru_cache_threaded(self): 1508 n, m = 5, 11 1509 def orig(x, y): 1510 return 3 * x + y 1511 f = self.module.lru_cache(maxsize=n*m)(orig) 1512 hits, misses, maxsize, currsize = f.cache_info() 1513 self.assertEqual(currsize, 0) 1514 1515 start = threading.Event() 1516 def full(k): 1517 start.wait(10) 1518 for _ in range(m): 1519 self.assertEqual(f(k, 0), orig(k, 0)) 1520 1521 def clear(): 1522 start.wait(10) 1523 for _ in range(2*m): 1524 f.cache_clear() 1525 1526 orig_si = sys.getswitchinterval() 1527 support.setswitchinterval(1e-6) 1528 try: 1529 # create n threads in order to fill cache 1530 threads = [threading.Thread(target=full, args=[k]) 1531 for k in range(n)] 1532 with support.start_threads(threads): 1533 start.set() 1534 1535 hits, misses, maxsize, currsize = f.cache_info() 1536 if self.module is py_functools: 1537 # XXX: Why can be not equal? 
1538 self.assertLessEqual(misses, n) 1539 self.assertLessEqual(hits, m*n - misses) 1540 else: 1541 self.assertEqual(misses, n) 1542 self.assertEqual(hits, m*n - misses) 1543 self.assertEqual(currsize, n) 1544 1545 # create n threads in order to fill cache and 1 to clear it 1546 threads = [threading.Thread(target=clear)] 1547 threads += [threading.Thread(target=full, args=[k]) 1548 for k in range(n)] 1549 start.clear() 1550 with support.start_threads(threads): 1551 start.set() 1552 finally: 1553 sys.setswitchinterval(orig_si) 1554 1555 def test_lru_cache_threaded2(self): 1556 # Simultaneous call with the same arguments 1557 n, m = 5, 7 1558 start = threading.Barrier(n+1) 1559 pause = threading.Barrier(n+1) 1560 stop = threading.Barrier(n+1) 1561 @self.module.lru_cache(maxsize=m*n) 1562 def f(x): 1563 pause.wait(10) 1564 return 3 * x 1565 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1566 def test(): 1567 for i in range(m): 1568 start.wait(10) 1569 self.assertEqual(f(i), 3 * i) 1570 stop.wait(10) 1571 threads = [threading.Thread(target=test) for k in range(n)] 1572 with support.start_threads(threads): 1573 for i in range(m): 1574 start.wait(10) 1575 stop.reset() 1576 pause.wait(10) 1577 start.reset() 1578 stop.wait(10) 1579 pause.reset() 1580 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1581 1582 def test_lru_cache_threaded3(self): 1583 @self.module.lru_cache(maxsize=2) 1584 def f(x): 1585 time.sleep(.01) 1586 return 3 * x 1587 def test(i, x): 1588 with self.subTest(thread=i): 1589 self.assertEqual(f(x), 3 * x, i) 1590 threads = [threading.Thread(target=test, args=(i, v)) 1591 for i, v in enumerate([1, 2, 2, 3, 2])] 1592 with support.start_threads(threads): 1593 pass 1594 1595 def test_need_for_rlock(self): 1596 # This will deadlock on an LRU cache that uses a regular lock 1597 1598 @self.module.lru_cache(maxsize=10) 1599 def test_func(x): 1600 'Used to demonstrate a reentrant lru_cache call within a single thread' 1601 return x 1602 1603 class 
DoubleEq: 1604 'Demonstrate a reentrant lru_cache call within a single thread' 1605 def __init__(self, x): 1606 self.x = x 1607 def __hash__(self): 1608 return self.x 1609 def __eq__(self, other): 1610 if self.x == 2: 1611 test_func(DoubleEq(1)) 1612 return self.x == other.x 1613 1614 test_func(DoubleEq(1)) # Load the cache 1615 test_func(DoubleEq(2)) # Load the cache 1616 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1617 DoubleEq(2)) # Verify the correct return value 1618 1619 def test_lru_method(self): 1620 class X(int): 1621 f_cnt = 0 1622 @self.module.lru_cache(2) 1623 def f(self, x): 1624 self.f_cnt += 1 1625 return x*10+self 1626 a = X(5) 1627 b = X(5) 1628 c = X(7) 1629 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1630 1631 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1632 self.assertEqual(a.f(x), x*10 + 5) 1633 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1634 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1635 1636 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1637 self.assertEqual(b.f(x), x*10 + 5) 1638 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1639 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1640 1641 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1642 self.assertEqual(c.f(x), x*10 + 7) 1643 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1644 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1645 1646 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1647 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1648 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1649 1650 def test_pickle(self): 1651 cls = self.__class__ 1652 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1653 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1654 with self.subTest(proto=proto, func=f): 1655 f_copy = pickle.loads(pickle.dumps(f, proto)) 1656 self.assertIs(f_copy, f) 1657 1658 def test_copy(self): 1659 cls = self.__class__ 1660 def orig(x, y): 1661 return 3 * x + y 1662 part = 
self.module.partial(orig, 2) 1663 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1664 self.module.lru_cache(2)(part)) 1665 for f in funcs: 1666 with self.subTest(func=f): 1667 f_copy = copy.copy(f) 1668 self.assertIs(f_copy, f) 1669 1670 def test_deepcopy(self): 1671 cls = self.__class__ 1672 def orig(x, y): 1673 return 3 * x + y 1674 part = self.module.partial(orig, 2) 1675 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1676 self.module.lru_cache(2)(part)) 1677 for f in funcs: 1678 with self.subTest(func=f): 1679 f_copy = copy.deepcopy(f) 1680 self.assertIs(f_copy, f) 1681 1682 def test_lru_cache_parameters(self): 1683 @self.module.lru_cache(maxsize=2) 1684 def f(): 1685 return 1 1686 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) 1687 1688 @self.module.lru_cache(maxsize=1000, typed=True) 1689 def f(): 1690 return 1 1691 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) 1692 1693 def test_lru_cache_weakrefable(self): 1694 @self.module.lru_cache 1695 def test_function(x): 1696 return x 1697 1698 class A: 1699 @self.module.lru_cache 1700 def test_method(self, x): 1701 return (self, x) 1702 1703 @staticmethod 1704 @self.module.lru_cache 1705 def test_staticmethod(x): 1706 return (self, x) 1707 1708 refs = [weakref.ref(test_function), 1709 weakref.ref(A.test_method), 1710 weakref.ref(A.test_staticmethod)] 1711 1712 for ref in refs: 1713 self.assertIsNotNone(ref()) 1714 1715 del A 1716 del test_function 1717 gc.collect() 1718 1719 for ref in refs: 1720 self.assertIsNone(ref()) 1721 1722 1723@py_functools.lru_cache() 1724def py_cached_func(x, y): 1725 return 3 * x + y 1726 1727@c_functools.lru_cache() 1728def c_cached_func(x, y): 1729 return 3 * x + y 1730 1731 1732class TestLRUPy(TestLRU, unittest.TestCase): 1733 module = py_functools 1734 cached_func = py_cached_func, 1735 1736 @module.lru_cache() 1737 def cached_meth(self, x, y): 1738 return 3 * x + y 1739 1740 @staticmethod 
1741 @module.lru_cache() 1742 def cached_staticmeth(x, y): 1743 return 3 * x + y 1744 1745 1746class TestLRUC(TestLRU, unittest.TestCase): 1747 module = c_functools 1748 cached_func = c_cached_func, 1749 1750 @module.lru_cache() 1751 def cached_meth(self, x, y): 1752 return 3 * x + y 1753 1754 @staticmethod 1755 @module.lru_cache() 1756 def cached_staticmeth(x, y): 1757 return 3 * x + y 1758 1759 1760class TestSingleDispatch(unittest.TestCase): 1761 def test_simple_overloads(self): 1762 @functools.singledispatch 1763 def g(obj): 1764 return "base" 1765 def g_int(i): 1766 return "integer" 1767 g.register(int, g_int) 1768 self.assertEqual(g("str"), "base") 1769 self.assertEqual(g(1), "integer") 1770 self.assertEqual(g([1,2,3]), "base") 1771 1772 def test_mro(self): 1773 @functools.singledispatch 1774 def g(obj): 1775 return "base" 1776 class A: 1777 pass 1778 class C(A): 1779 pass 1780 class B(A): 1781 pass 1782 class D(C, B): 1783 pass 1784 def g_A(a): 1785 return "A" 1786 def g_B(b): 1787 return "B" 1788 g.register(A, g_A) 1789 g.register(B, g_B) 1790 self.assertEqual(g(A()), "A") 1791 self.assertEqual(g(B()), "B") 1792 self.assertEqual(g(C()), "A") 1793 self.assertEqual(g(D()), "B") 1794 1795 def test_register_decorator(self): 1796 @functools.singledispatch 1797 def g(obj): 1798 return "base" 1799 @g.register(int) 1800 def g_int(i): 1801 return "int %s" % (i,) 1802 self.assertEqual(g(""), "base") 1803 self.assertEqual(g(12), "int 12") 1804 self.assertIs(g.dispatch(int), g_int) 1805 self.assertIs(g.dispatch(object), g.dispatch(str)) 1806 # Note: in the assert above this is not g. 1807 # @singledispatch returns the wrapper. 
1808 1809 def test_wrapping_attributes(self): 1810 @functools.singledispatch 1811 def g(obj): 1812 "Simple test" 1813 return "Test" 1814 self.assertEqual(g.__name__, "g") 1815 if sys.flags.optimize < 2: 1816 self.assertEqual(g.__doc__, "Simple test") 1817 1818 @unittest.skipUnless(decimal, 'requires _decimal') 1819 @support.cpython_only 1820 def test_c_classes(self): 1821 @functools.singledispatch 1822 def g(obj): 1823 return "base" 1824 @g.register(decimal.DecimalException) 1825 def _(obj): 1826 return obj.args 1827 subn = decimal.Subnormal("Exponent < Emin") 1828 rnd = decimal.Rounded("Number got rounded") 1829 self.assertEqual(g(subn), ("Exponent < Emin",)) 1830 self.assertEqual(g(rnd), ("Number got rounded",)) 1831 @g.register(decimal.Subnormal) 1832 def _(obj): 1833 return "Too small to care." 1834 self.assertEqual(g(subn), "Too small to care.") 1835 self.assertEqual(g(rnd), ("Number got rounded",)) 1836 1837 def test_compose_mro(self): 1838 # None of the examples in this test depend on haystack ordering. 1839 c = collections.abc 1840 mro = functools._compose_mro 1841 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1842 for haystack in permutations(bases): 1843 m = mro(dict, haystack) 1844 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1845 c.Collection, c.Sized, c.Iterable, 1846 c.Container, object]) 1847 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1848 for haystack in permutations(bases): 1849 m = mro(collections.ChainMap, haystack) 1850 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1851 c.Collection, c.Sized, c.Iterable, 1852 c.Container, object]) 1853 1854 # If there's a generic function with implementations registered for 1855 # both Sized and Container, passing a defaultdict to it results in an 1856 # ambiguous dispatch which will cause a RuntimeError (see 1857 # test_mro_conflicts). 
1858 bases = [c.Container, c.Sized, str] 1859 for haystack in permutations(bases): 1860 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1861 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1862 c.Container, object]) 1863 1864 # MutableSequence below is registered directly on D. In other words, it 1865 # precedes MutableMapping which means single dispatch will always 1866 # choose MutableSequence here. 1867 class D(collections.defaultdict): 1868 pass 1869 c.MutableSequence.register(D) 1870 bases = [c.MutableSequence, c.MutableMapping] 1871 for haystack in permutations(bases): 1872 m = mro(D, bases) 1873 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1874 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1875 c.Collection, c.Sized, c.Iterable, c.Container, 1876 object]) 1877 1878 # Container and Callable are registered on different base classes and 1879 # a generic function supporting both should always pick the Callable 1880 # implementation if a C instance is passed. 
1881 class C(collections.defaultdict): 1882 def __call__(self): 1883 pass 1884 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1885 for haystack in permutations(bases): 1886 m = mro(C, haystack) 1887 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1888 c.Collection, c.Sized, c.Iterable, 1889 c.Container, object]) 1890 1891 def test_register_abc(self): 1892 c = collections.abc 1893 d = {"a": "b"} 1894 l = [1, 2, 3] 1895 s = {object(), None} 1896 f = frozenset(s) 1897 t = (1, 2, 3) 1898 @functools.singledispatch 1899 def g(obj): 1900 return "base" 1901 self.assertEqual(g(d), "base") 1902 self.assertEqual(g(l), "base") 1903 self.assertEqual(g(s), "base") 1904 self.assertEqual(g(f), "base") 1905 self.assertEqual(g(t), "base") 1906 g.register(c.Sized, lambda obj: "sized") 1907 self.assertEqual(g(d), "sized") 1908 self.assertEqual(g(l), "sized") 1909 self.assertEqual(g(s), "sized") 1910 self.assertEqual(g(f), "sized") 1911 self.assertEqual(g(t), "sized") 1912 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1913 self.assertEqual(g(d), "mutablemapping") 1914 self.assertEqual(g(l), "sized") 1915 self.assertEqual(g(s), "sized") 1916 self.assertEqual(g(f), "sized") 1917 self.assertEqual(g(t), "sized") 1918 g.register(collections.ChainMap, lambda obj: "chainmap") 1919 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1920 self.assertEqual(g(l), "sized") 1921 self.assertEqual(g(s), "sized") 1922 self.assertEqual(g(f), "sized") 1923 self.assertEqual(g(t), "sized") 1924 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1925 self.assertEqual(g(d), "mutablemapping") 1926 self.assertEqual(g(l), "mutablesequence") 1927 self.assertEqual(g(s), "sized") 1928 self.assertEqual(g(f), "sized") 1929 self.assertEqual(g(t), "sized") 1930 g.register(c.MutableSet, lambda obj: "mutableset") 1931 self.assertEqual(g(d), "mutablemapping") 1932 self.assertEqual(g(l), "mutablesequence") 1933 self.assertEqual(g(s), 
"mutableset") 1934 self.assertEqual(g(f), "sized") 1935 self.assertEqual(g(t), "sized") 1936 g.register(c.Mapping, lambda obj: "mapping") 1937 self.assertEqual(g(d), "mutablemapping") # not specific enough 1938 self.assertEqual(g(l), "mutablesequence") 1939 self.assertEqual(g(s), "mutableset") 1940 self.assertEqual(g(f), "sized") 1941 self.assertEqual(g(t), "sized") 1942 g.register(c.Sequence, lambda obj: "sequence") 1943 self.assertEqual(g(d), "mutablemapping") 1944 self.assertEqual(g(l), "mutablesequence") 1945 self.assertEqual(g(s), "mutableset") 1946 self.assertEqual(g(f), "sized") 1947 self.assertEqual(g(t), "sequence") 1948 g.register(c.Set, lambda obj: "set") 1949 self.assertEqual(g(d), "mutablemapping") 1950 self.assertEqual(g(l), "mutablesequence") 1951 self.assertEqual(g(s), "mutableset") 1952 self.assertEqual(g(f), "set") 1953 self.assertEqual(g(t), "sequence") 1954 g.register(dict, lambda obj: "dict") 1955 self.assertEqual(g(d), "dict") 1956 self.assertEqual(g(l), "mutablesequence") 1957 self.assertEqual(g(s), "mutableset") 1958 self.assertEqual(g(f), "set") 1959 self.assertEqual(g(t), "sequence") 1960 g.register(list, lambda obj: "list") 1961 self.assertEqual(g(d), "dict") 1962 self.assertEqual(g(l), "list") 1963 self.assertEqual(g(s), "mutableset") 1964 self.assertEqual(g(f), "set") 1965 self.assertEqual(g(t), "sequence") 1966 g.register(set, lambda obj: "concrete-set") 1967 self.assertEqual(g(d), "dict") 1968 self.assertEqual(g(l), "list") 1969 self.assertEqual(g(s), "concrete-set") 1970 self.assertEqual(g(f), "set") 1971 self.assertEqual(g(t), "sequence") 1972 g.register(frozenset, lambda obj: "frozen-set") 1973 self.assertEqual(g(d), "dict") 1974 self.assertEqual(g(l), "list") 1975 self.assertEqual(g(s), "concrete-set") 1976 self.assertEqual(g(f), "frozen-set") 1977 self.assertEqual(g(t), "sequence") 1978 g.register(tuple, lambda obj: "tuple") 1979 self.assertEqual(g(d), "dict") 1980 self.assertEqual(g(l), "list") 1981 self.assertEqual(g(s), 
"concrete-set") 1982 self.assertEqual(g(f), "frozen-set") 1983 self.assertEqual(g(t), "tuple") 1984 1985 def test_c3_abc(self): 1986 c = collections.abc 1987 mro = functools._c3_mro 1988 class A(object): 1989 pass 1990 class B(A): 1991 def __len__(self): 1992 return 0 # implies Sized 1993 @c.Container.register 1994 class C(object): 1995 pass 1996 class D(object): 1997 pass # unrelated 1998 class X(D, C, B): 1999 def __call__(self): 2000 pass # implies Callable 2001 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 2002 for abcs in permutations([c.Sized, c.Callable, c.Container]): 2003 self.assertEqual(mro(X, abcs=abcs), expected) 2004 # unrelated ABCs don't appear in the resulting MRO 2005 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 2006 self.assertEqual(mro(X, abcs=many_abcs), expected) 2007 2008 def test_false_meta(self): 2009 # see issue23572 2010 class MetaA(type): 2011 def __len__(self): 2012 return 0 2013 class A(metaclass=MetaA): 2014 pass 2015 class AA(A): 2016 pass 2017 @functools.singledispatch 2018 def fun(a): 2019 return 'base A' 2020 @fun.register(A) 2021 def _(a): 2022 return 'fun A' 2023 aa = AA() 2024 self.assertEqual(fun(aa), 'fun A') 2025 2026 def test_mro_conflicts(self): 2027 c = collections.abc 2028 @functools.singledispatch 2029 def g(arg): 2030 return "base" 2031 class O(c.Sized): 2032 def __len__(self): 2033 return 0 2034 o = O() 2035 self.assertEqual(g(o), "base") 2036 g.register(c.Iterable, lambda arg: "iterable") 2037 g.register(c.Container, lambda arg: "container") 2038 g.register(c.Sized, lambda arg: "sized") 2039 g.register(c.Set, lambda arg: "set") 2040 self.assertEqual(g(o), "sized") 2041 c.Iterable.register(O) 2042 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 2043 c.Container.register(O) 2044 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 2045 c.Set.register(O) 2046 self.assertEqual(g(o), "set") # because c.Set is a subclass of 2047 # c.Sized and 
c.Container 2048 class P: 2049 pass 2050 p = P() 2051 self.assertEqual(g(p), "base") 2052 c.Iterable.register(P) 2053 self.assertEqual(g(p), "iterable") 2054 c.Container.register(P) 2055 with self.assertRaises(RuntimeError) as re_one: 2056 g(p) 2057 self.assertIn( 2058 str(re_one.exception), 2059 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2060 "or <class 'collections.abc.Iterable'>"), 2061 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2062 "or <class 'collections.abc.Container'>")), 2063 ) 2064 class Q(c.Sized): 2065 def __len__(self): 2066 return 0 2067 q = Q() 2068 self.assertEqual(g(q), "sized") 2069 c.Iterable.register(Q) 2070 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2071 c.Set.register(Q) 2072 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2073 # c.Sized and c.Iterable 2074 @functools.singledispatch 2075 def h(arg): 2076 return "base" 2077 @h.register(c.Sized) 2078 def _(arg): 2079 return "sized" 2080 @h.register(c.Container) 2081 def _(arg): 2082 return "container" 2083 # Even though Sized and Container are explicit bases of MutableMapping, 2084 # this ABC is implicitly registered on defaultdict which makes all of 2085 # MutableMapping's bases implicit as well from defaultdict's 2086 # perspective. 
2087 with self.assertRaises(RuntimeError) as re_two: 2088 h(collections.defaultdict(lambda: 0)) 2089 self.assertIn( 2090 str(re_two.exception), 2091 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2092 "or <class 'collections.abc.Sized'>"), 2093 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2094 "or <class 'collections.abc.Container'>")), 2095 ) 2096 class R(collections.defaultdict): 2097 pass 2098 c.MutableSequence.register(R) 2099 @functools.singledispatch 2100 def i(arg): 2101 return "base" 2102 @i.register(c.MutableMapping) 2103 def _(arg): 2104 return "mapping" 2105 @i.register(c.MutableSequence) 2106 def _(arg): 2107 return "sequence" 2108 r = R() 2109 self.assertEqual(i(r), "sequence") 2110 class S: 2111 pass 2112 class T(S, c.Sized): 2113 def __len__(self): 2114 return 0 2115 t = T() 2116 self.assertEqual(h(t), "sized") 2117 c.Container.register(T) 2118 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2119 class U: 2120 def __len__(self): 2121 return 0 2122 u = U() 2123 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2124 # from the existence of __len__() 2125 c.Container.register(U) 2126 # There is no preference for registered versus inferred ABCs. 
2127 with self.assertRaises(RuntimeError) as re_three: 2128 h(u) 2129 self.assertIn( 2130 str(re_three.exception), 2131 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2132 "or <class 'collections.abc.Sized'>"), 2133 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2134 "or <class 'collections.abc.Container'>")), 2135 ) 2136 class V(c.Sized, S): 2137 def __len__(self): 2138 return 0 2139 @functools.singledispatch 2140 def j(arg): 2141 return "base" 2142 @j.register(S) 2143 def _(arg): 2144 return "s" 2145 @j.register(c.Container) 2146 def _(arg): 2147 return "container" 2148 v = V() 2149 self.assertEqual(j(v), "s") 2150 c.Container.register(V) 2151 self.assertEqual(j(v), "container") # because it ends up right after 2152 # Sized in the MRO 2153 2154 def test_cache_invalidation(self): 2155 from collections import UserDict 2156 import weakref 2157 2158 class TracingDict(UserDict): 2159 def __init__(self, *args, **kwargs): 2160 super(TracingDict, self).__init__(*args, **kwargs) 2161 self.set_ops = [] 2162 self.get_ops = [] 2163 def __getitem__(self, key): 2164 result = self.data[key] 2165 self.get_ops.append(key) 2166 return result 2167 def __setitem__(self, key, value): 2168 self.set_ops.append(key) 2169 self.data[key] = value 2170 def clear(self): 2171 self.data.clear() 2172 2173 td = TracingDict() 2174 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2175 c = collections.abc 2176 @functools.singledispatch 2177 def g(arg): 2178 return "base" 2179 d = {} 2180 l = [] 2181 self.assertEqual(len(td), 0) 2182 self.assertEqual(g(d), "base") 2183 self.assertEqual(len(td), 1) 2184 self.assertEqual(td.get_ops, []) 2185 self.assertEqual(td.set_ops, [dict]) 2186 self.assertEqual(td.data[dict], g.registry[object]) 2187 self.assertEqual(g(l), "base") 2188 self.assertEqual(len(td), 2) 2189 self.assertEqual(td.get_ops, []) 2190 self.assertEqual(td.set_ops, [dict, list]) 2191 self.assertEqual(td.data[dict], g.registry[object]) 2192 
self.assertEqual(td.data[list], g.registry[object]) 2193 self.assertEqual(td.data[dict], td.data[list]) 2194 self.assertEqual(g(l), "base") 2195 self.assertEqual(g(d), "base") 2196 self.assertEqual(td.get_ops, [list, dict]) 2197 self.assertEqual(td.set_ops, [dict, list]) 2198 g.register(list, lambda arg: "list") 2199 self.assertEqual(td.get_ops, [list, dict]) 2200 self.assertEqual(len(td), 0) 2201 self.assertEqual(g(d), "base") 2202 self.assertEqual(len(td), 1) 2203 self.assertEqual(td.get_ops, [list, dict]) 2204 self.assertEqual(td.set_ops, [dict, list, dict]) 2205 self.assertEqual(td.data[dict], 2206 functools._find_impl(dict, g.registry)) 2207 self.assertEqual(g(l), "list") 2208 self.assertEqual(len(td), 2) 2209 self.assertEqual(td.get_ops, [list, dict]) 2210 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2211 self.assertEqual(td.data[list], 2212 functools._find_impl(list, g.registry)) 2213 class X: 2214 pass 2215 c.MutableMapping.register(X) # Will not invalidate the cache, 2216 # not using ABCs yet. 
2217 self.assertEqual(g(d), "base") 2218 self.assertEqual(g(l), "list") 2219 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2220 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2221 g.register(c.Sized, lambda arg: "sized") 2222 self.assertEqual(len(td), 0) 2223 self.assertEqual(g(d), "sized") 2224 self.assertEqual(len(td), 1) 2225 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2226 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2227 self.assertEqual(g(l), "list") 2228 self.assertEqual(len(td), 2) 2229 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2230 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2231 self.assertEqual(g(l), "list") 2232 self.assertEqual(g(d), "sized") 2233 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2234 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2235 g.dispatch(list) 2236 g.dispatch(dict) 2237 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2238 list, dict]) 2239 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2240 c.MutableSet.register(X) # Will invalidate the cache. 2241 self.assertEqual(len(td), 2) # Stale cache. 
2242 self.assertEqual(g(l), "list") 2243 self.assertEqual(len(td), 1) 2244 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2245 self.assertEqual(len(td), 0) 2246 self.assertEqual(g(d), "mutablemapping") 2247 self.assertEqual(len(td), 1) 2248 self.assertEqual(g(l), "list") 2249 self.assertEqual(len(td), 2) 2250 g.register(dict, lambda arg: "dict") 2251 self.assertEqual(g(d), "dict") 2252 self.assertEqual(g(l), "list") 2253 g._clear_cache() 2254 self.assertEqual(len(td), 0) 2255 2256 def test_annotations(self): 2257 @functools.singledispatch 2258 def i(arg): 2259 return "base" 2260 @i.register 2261 def _(arg: collections.abc.Mapping): 2262 return "mapping" 2263 @i.register 2264 def _(arg: "collections.abc.Sequence"): 2265 return "sequence" 2266 self.assertEqual(i(None), "base") 2267 self.assertEqual(i({"a": 1}), "mapping") 2268 self.assertEqual(i([1, 2, 3]), "sequence") 2269 self.assertEqual(i((1, 2, 3)), "sequence") 2270 self.assertEqual(i("str"), "sequence") 2271 2272 # Registering classes as callables doesn't work with annotations, 2273 # you need to pass the type explicitly. 
2274 @i.register(str) 2275 class _: 2276 def __init__(self, arg): 2277 self.arg = arg 2278 2279 def __eq__(self, other): 2280 return self.arg == other 2281 self.assertEqual(i("str"), "str") 2282 2283 def test_method_register(self): 2284 class A: 2285 @functools.singledispatchmethod 2286 def t(self, arg): 2287 self.arg = "base" 2288 @t.register(int) 2289 def _(self, arg): 2290 self.arg = "int" 2291 @t.register(str) 2292 def _(self, arg): 2293 self.arg = "str" 2294 a = A() 2295 2296 a.t(0) 2297 self.assertEqual(a.arg, "int") 2298 aa = A() 2299 self.assertFalse(hasattr(aa, 'arg')) 2300 a.t('') 2301 self.assertEqual(a.arg, "str") 2302 aa = A() 2303 self.assertFalse(hasattr(aa, 'arg')) 2304 a.t(0.0) 2305 self.assertEqual(a.arg, "base") 2306 aa = A() 2307 self.assertFalse(hasattr(aa, 'arg')) 2308 2309 def test_staticmethod_register(self): 2310 class A: 2311 @functools.singledispatchmethod 2312 @staticmethod 2313 def t(arg): 2314 return arg 2315 @t.register(int) 2316 @staticmethod 2317 def _(arg): 2318 return isinstance(arg, int) 2319 @t.register(str) 2320 @staticmethod 2321 def _(arg): 2322 return isinstance(arg, str) 2323 a = A() 2324 2325 self.assertTrue(A.t(0)) 2326 self.assertTrue(A.t('')) 2327 self.assertEqual(A.t(0.0), 0.0) 2328 2329 def test_classmethod_register(self): 2330 class A: 2331 def __init__(self, arg): 2332 self.arg = arg 2333 2334 @functools.singledispatchmethod 2335 @classmethod 2336 def t(cls, arg): 2337 return cls("base") 2338 @t.register(int) 2339 @classmethod 2340 def _(cls, arg): 2341 return cls("int") 2342 @t.register(str) 2343 @classmethod 2344 def _(cls, arg): 2345 return cls("str") 2346 2347 self.assertEqual(A.t(0).arg, "int") 2348 self.assertEqual(A.t('').arg, "str") 2349 self.assertEqual(A.t(0.0).arg, "base") 2350 2351 def test_callable_register(self): 2352 class A: 2353 def __init__(self, arg): 2354 self.arg = arg 2355 2356 @functools.singledispatchmethod 2357 @classmethod 2358 def t(cls, arg): 2359 return cls("base") 2360 2361 
@A.t.register(int) 2362 @classmethod 2363 def _(cls, arg): 2364 return cls("int") 2365 @A.t.register(str) 2366 @classmethod 2367 def _(cls, arg): 2368 return cls("str") 2369 2370 self.assertEqual(A.t(0).arg, "int") 2371 self.assertEqual(A.t('').arg, "str") 2372 self.assertEqual(A.t(0.0).arg, "base") 2373 2374 def test_abstractmethod_register(self): 2375 class Abstract(abc.ABCMeta): 2376 2377 @functools.singledispatchmethod 2378 @abc.abstractmethod 2379 def add(self, x, y): 2380 pass 2381 2382 self.assertTrue(Abstract.add.__isabstractmethod__) 2383 2384 def test_type_ann_register(self): 2385 class A: 2386 @functools.singledispatchmethod 2387 def t(self, arg): 2388 return "base" 2389 @t.register 2390 def _(self, arg: int): 2391 return "int" 2392 @t.register 2393 def _(self, arg: str): 2394 return "str" 2395 a = A() 2396 2397 self.assertEqual(a.t(0), "int") 2398 self.assertEqual(a.t(''), "str") 2399 self.assertEqual(a.t(0.0), "base") 2400 2401 def test_invalid_registrations(self): 2402 msg_prefix = "Invalid first argument to `register()`: " 2403 msg_suffix = ( 2404 ". Use either `@register(some_class)` or plain `@register` on an " 2405 "annotated function." 2406 ) 2407 @functools.singledispatch 2408 def i(arg): 2409 return "base" 2410 with self.assertRaises(TypeError) as exc: 2411 @i.register(42) 2412 def _(arg): 2413 return "I annotated with a non-type" 2414 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2415 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2416 with self.assertRaises(TypeError) as exc: 2417 @i.register 2418 def _(arg): 2419 return "I forgot to annotate" 2420 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2421 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2422 )) 2423 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2424 2425 with self.assertRaises(TypeError) as exc: 2426 @i.register 2427 def _(arg: typing.Iterable[str]): 2428 # At runtime, dispatching on generics is impossible. 
2429 # When registering implementations with singledispatch, avoid 2430 # types from `typing`. Instead, annotate with regular types 2431 # or ABCs. 2432 return "I annotated with a generic collection" 2433 self.assertTrue(str(exc.exception).startswith( 2434 "Invalid annotation for 'arg'." 2435 )) 2436 self.assertTrue(str(exc.exception).endswith( 2437 'typing.Iterable[str] is not a class.' 2438 )) 2439 2440 def test_invalid_positional_argument(self): 2441 @functools.singledispatch 2442 def f(*args): 2443 pass 2444 msg = 'f requires at least 1 positional argument' 2445 with self.assertRaisesRegex(TypeError, msg): 2446 f() 2447 2448 2449class CachedCostItem: 2450 _cost = 1 2451 2452 def __init__(self): 2453 self.lock = py_functools.RLock() 2454 2455 @py_functools.cached_property 2456 def cost(self): 2457 """The cost of the item.""" 2458 with self.lock: 2459 self._cost += 1 2460 return self._cost 2461 2462 2463class OptionallyCachedCostItem: 2464 _cost = 1 2465 2466 def get_cost(self): 2467 """The cost of the item.""" 2468 self._cost += 1 2469 return self._cost 2470 2471 cached_cost = py_functools.cached_property(get_cost) 2472 2473 2474class CachedCostItemWait: 2475 2476 def __init__(self, event): 2477 self._cost = 1 2478 self.lock = py_functools.RLock() 2479 self.event = event 2480 2481 @py_functools.cached_property 2482 def cost(self): 2483 self.event.wait(1) 2484 with self.lock: 2485 self._cost += 1 2486 return self._cost 2487 2488 2489class CachedCostItemWithSlots: 2490 __slots__ = ('_cost') 2491 2492 def __init__(self): 2493 self._cost = 1 2494 2495 @py_functools.cached_property 2496 def cost(self): 2497 raise RuntimeError('never called, slots not supported') 2498 2499 2500class TestCachedProperty(unittest.TestCase): 2501 def test_cached(self): 2502 item = CachedCostItem() 2503 self.assertEqual(item.cost, 2) 2504 self.assertEqual(item.cost, 2) # not 3 2505 2506 def test_cached_attribute_name_differs_from_func_name(self): 2507 item = OptionallyCachedCostItem() 2508 
self.assertEqual(item.get_cost(), 2) 2509 self.assertEqual(item.cached_cost, 3) 2510 self.assertEqual(item.get_cost(), 4) 2511 self.assertEqual(item.cached_cost, 3) 2512 2513 def test_threaded(self): 2514 go = threading.Event() 2515 item = CachedCostItemWait(go) 2516 2517 num_threads = 3 2518 2519 orig_si = sys.getswitchinterval() 2520 sys.setswitchinterval(1e-6) 2521 try: 2522 threads = [ 2523 threading.Thread(target=lambda: item.cost) 2524 for k in range(num_threads) 2525 ] 2526 with support.start_threads(threads): 2527 go.set() 2528 finally: 2529 sys.setswitchinterval(orig_si) 2530 2531 self.assertEqual(item.cost, 2) 2532 2533 def test_object_with_slots(self): 2534 item = CachedCostItemWithSlots() 2535 with self.assertRaisesRegex( 2536 TypeError, 2537 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2538 ): 2539 item.cost 2540 2541 def test_immutable_dict(self): 2542 class MyMeta(type): 2543 @py_functools.cached_property 2544 def prop(self): 2545 return True 2546 2547 class MyClass(metaclass=MyMeta): 2548 pass 2549 2550 with self.assertRaisesRegex( 2551 TypeError, 2552 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 2553 ): 2554 MyClass.prop 2555 2556 def test_reuse_different_names(self): 2557 """Disallow this case because decorated function a would not be cached.""" 2558 with self.assertRaises(RuntimeError) as ctx: 2559 class ReusedCachedProperty: 2560 @py_functools.cached_property 2561 def a(self): 2562 pass 2563 2564 b = a 2565 2566 self.assertEqual( 2567 str(ctx.exception.__context__), 2568 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 2569 ) 2570 2571 def test_reuse_same_name(self): 2572 """Reusing a cached_property on different classes under the same name is OK.""" 2573 counter = 0 2574 2575 @py_functools.cached_property 2576 def _cp(_self): 2577 nonlocal counter 2578 counter += 1 2579 return counter 2580 
2581 class A: 2582 cp = _cp 2583 2584 class B: 2585 cp = _cp 2586 2587 a = A() 2588 b = B() 2589 2590 self.assertEqual(a.cp, 1) 2591 self.assertEqual(b.cp, 2) 2592 self.assertEqual(a.cp, 1) 2593 2594 def test_set_name_not_called(self): 2595 cp = py_functools.cached_property(lambda s: None) 2596 class Foo: 2597 pass 2598 2599 Foo.cp = cp 2600 2601 with self.assertRaisesRegex( 2602 TypeError, 2603 "Cannot use cached_property instance without calling __set_name__ on it.", 2604 ): 2605 Foo().cp 2606 2607 def test_access_from_class(self): 2608 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 2609 2610 def test_doc(self): 2611 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 2612 2613 2614if __name__ == '__main__': 2615 unittest.main() 2616