1import unittest 2from test import test_support 3 4def funcattrs(**kwds): 5 def decorate(func): 6 func.__dict__.update(kwds) 7 return func 8 return decorate 9 10class MiscDecorators (object): 11 @staticmethod 12 def author(name): 13 def decorate(func): 14 func.__dict__['author'] = name 15 return func 16 return decorate 17 18# ----------------------------------------------- 19 20class DbcheckError (Exception): 21 def __init__(self, exprstr, func, args, kwds): 22 # A real version of this would set attributes here 23 Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" % 24 (exprstr, func, args, kwds)) 25 26 27def dbcheck(exprstr, globals=None, locals=None): 28 "Decorator to implement debugging assertions" 29 def decorate(func): 30 expr = compile(exprstr, "dbcheck-%s" % func.func_name, "eval") 31 def check(*args, **kwds): 32 if not eval(expr, globals, locals): 33 raise DbcheckError(exprstr, func, args, kwds) 34 return func(*args, **kwds) 35 return check 36 return decorate 37 38# ----------------------------------------------- 39 40def countcalls(counts): 41 "Decorator to count calls to a function" 42 def decorate(func): 43 func_name = func.func_name 44 counts[func_name] = 0 45 def call(*args, **kwds): 46 counts[func_name] += 1 47 return func(*args, **kwds) 48 call.func_name = func_name 49 return call 50 return decorate 51 52# ----------------------------------------------- 53 54def memoize(func): 55 saved = {} 56 def call(*args): 57 try: 58 return saved[args] 59 except KeyError: 60 res = func(*args) 61 saved[args] = res 62 return res 63 except TypeError: 64 # Unhashable argument 65 return func(*args) 66 call.func_name = func.func_name 67 return call 68 69# ----------------------------------------------- 70 71class TestDecorators(unittest.TestCase): 72 73 def test_single(self): 74 class C(object): 75 @staticmethod 76 def foo(): return 42 77 self.assertEqual(C.foo(), 42) 78 self.assertEqual(C().foo(), 42) 79 80 def test_staticmethod_function(self): 81 @staticmethod 82 def notamethod(x): 83 return x 84 self.assertRaises(TypeError, notamethod, 1) 85 86 def test_dotted(self): 87 decorators = MiscDecorators() 88 @decorators.author('Cleese') 89 def foo(): return 42 90 self.assertEqual(foo(), 42) 91 self.assertEqual(foo.author, 'Cleese') 92 93 def test_argforms(self): 94 # A few tests of argument passing, as we use restricted form 95 # of expressions for decorators. 96 97 def noteargs(*args, **kwds): 98 def decorate(func): 99 setattr(func, 'dbval', (args, kwds)) 100 return func 101 return decorate 102 103 args = ( 'Now', 'is', 'the', 'time' ) 104 kwds = dict(one=1, two=2) 105 @noteargs(*args, **kwds) 106 def f1(): return 42 107 self.assertEqual(f1(), 42) 108 self.assertEqual(f1.dbval, (args, kwds)) 109 110 @noteargs('terry', 'gilliam', eric='idle', john='cleese') 111 def f2(): return 84 112 self.assertEqual(f2(), 84) 113 self.assertEqual(f2.dbval, (('terry', 'gilliam'), 114 dict(eric='idle', john='cleese'))) 115 116 @noteargs(1, 2,) 117 def f3(): pass 118 self.assertEqual(f3.dbval, ((1, 2), {})) 119 120 def test_dbcheck(self): 121 @dbcheck('args[1] is not None') 122 def f(a, b): 123 return a + b 124 self.assertEqual(f(1, 2), 3) 125 self.assertRaises(DbcheckError, f, 1, None) 126 127 def test_memoize(self): 128 counts = {} 129 130 @memoize 131 @countcalls(counts) 132 def double(x): 133 return x * 2 134 self.assertEqual(double.func_name, 'double') 135 136 self.assertEqual(counts, dict(double=0)) 137 138 # Only the first call with a given argument bumps the call count: 139 # 140 self.assertEqual(double(2), 4) 141 self.assertEqual(counts['double'], 1) 142 self.assertEqual(double(2), 4) 143 self.assertEqual(counts['double'], 1) 144 self.assertEqual(double(3), 6) 145 self.assertEqual(counts['double'], 2) 146 147 # Unhashable arguments do not get memoized: 148 # 149 self.assertEqual(double([10]), [10, 10]) 150 self.assertEqual(counts['double'], 3) 151 self.assertEqual(double([10]), [10, 10]) 152 self.assertEqual(counts['double'], 4) 153 154 def test_errors(self): 155 # Test syntax restrictions - these are all compile-time errors: 156 # 157 for expr in [ "1+2", "x[3]", "(1, 2)" ]: 158 # Sanity check: is expr is a valid expression by itself? 159 compile(expr, "testexpr", "exec") 160 161 codestr = "@%s\ndef f(): pass" % expr 162 self.assertRaises(SyntaxError, compile, codestr, "test", "exec") 163 164 # You can't put multiple decorators on a single line: 165 # 166 self.assertRaises(SyntaxError, compile, 167 "@f1 @f2\ndef f(): pass", "test", "exec") 168 169 # Test runtime errors 170 171 def unimp(func): 172 raise NotImplementedError 173 context = dict(nullval=None, unimp=unimp) 174 175 for expr, exc in [ ("undef", NameError), 176 ("nullval", TypeError), 177 ("nullval.attr", AttributeError), 178 ("unimp", NotImplementedError)]: 179 codestr = "@%s\ndef f(): pass\nassert f() is None" % expr 180 code = compile(codestr, "test", "exec") 181 self.assertRaises(exc, eval, code, context) 182 183 def test_double(self): 184 class C(object): 185 @funcattrs(abc=1, xyz="haha") 186 @funcattrs(booh=42) 187 def foo(self): return 42 188 self.assertEqual(C().foo(), 42) 189 self.assertEqual(C.foo.abc, 1) 190 self.assertEqual(C.foo.xyz, "haha") 191 self.assertEqual(C.foo.booh, 42) 192 193 def test_order(self): 194 # Test that decorators are applied in the proper order to the function 195 # they are decorating. 196 def callnum(num): 197 """Decorator factory that returns a decorator that replaces the 198 passed-in function with one that returns the value of 'num'""" 199 def deco(func): 200 return lambda: num 201 return deco 202 @callnum(2) 203 @callnum(1) 204 def foo(): return 42 205 self.assertEqual(foo(), 2, 206 "Application order of decorators is incorrect") 207 208 def test_eval_order(self): 209 # Evaluating a decorated function involves four steps for each 210 # decorator-maker (the function that returns a decorator): 211 # 212 # 1: Evaluate the decorator-maker name 213 # 2: Evaluate the decorator-maker arguments (if any) 214 # 3: Call the decorator-maker to make a decorator 215 # 4: Call the decorator 216 # 217 # When there are multiple decorators, these steps should be 218 # performed in the above order for each decorator, but we should 219 # iterate through the decorators in the reverse of the order they 220 # appear in the source. 221 222 actions = [] 223 224 def make_decorator(tag): 225 actions.append('makedec' + tag) 226 def decorate(func): 227 actions.append('calldec' + tag) 228 return func 229 return decorate 230 231 class NameLookupTracer (object): 232 def __init__(self, index): 233 self.index = index 234 235 def __getattr__(self, fname): 236 if fname == 'make_decorator': 237 opname, res = ('evalname', make_decorator) 238 elif fname == 'arg': 239 opname, res = ('evalargs', str(self.index)) 240 else: 241 assert False, "Unknown attrname %s" % fname 242 actions.append('%s%d' % (opname, self.index)) 243 return res 244 245 c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ]) 246 247 expected_actions = [ 'evalname1', 'evalargs1', 'makedec1', 248 'evalname2', 'evalargs2', 'makedec2', 249 'evalname3', 'evalargs3', 'makedec3', 250 'calldec3', 'calldec2', 'calldec1' ] 251 252 actions = [] 253 @c1.make_decorator(c1.arg) 254 @c2.make_decorator(c2.arg) 255 @c3.make_decorator(c3.arg) 256 def foo(): return 42 257 self.assertEqual(foo(), 42) 258 259 self.assertEqual(actions, expected_actions) 260 261 # Test the equivalence claim in chapter 7 of the reference manual. 262 # 263 actions = [] 264 def bar(): return 42 265 bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar))) 266 self.assertEqual(bar(), 42) 267 self.assertEqual(actions, expected_actions) 268 269class TestClassDecorators(unittest.TestCase): 270 271 def test_simple(self): 272 def plain(x): 273 x.extra = 'Hello' 274 return x 275 @plain 276 class C(object): pass 277 self.assertEqual(C.extra, 'Hello') 278 279 def test_double(self): 280 def ten(x): 281 x.extra = 10 282 return x 283 def add_five(x): 284 x.extra += 5 285 return x 286 287 @add_five 288 @ten 289 class C(object): pass 290 self.assertEqual(C.extra, 15) 291 292 def test_order(self): 293 def applied_first(x): 294 x.extra = 'first' 295 return x 296 def applied_second(x): 297 x.extra = 'second' 298 return x 299 @applied_second 300 @applied_first 301 class C(object): pass 302 self.assertEqual(C.extra, 'second') 303 304def test_main(): 305 test_support.run_unittest(TestDecorators) 306 test_support.run_unittest(TestClassDecorators) 307 308if __name__=="__main__": 309 test_main() 310