1import unittest
2
# Module-level global used by NamedExpressionScopeTest: the global-scope
# tests rebind it via ``:=`` (and reset it to None afterwards), or check
# that a comprehension without ``global`` leaves it as None.
GLOBAL_VAR = None
4
class NamedExpressionInvalidTest(unittest.TestCase):
    """Check that misplaced or malformed ``:=`` expressions fail to compile.

    Every test executes a source string and asserts that compilation raises
    ``SyntaxError`` whose message matches the given regular expression, so
    the message text is part of what is being verified.
    """

    def _assert_syntax_error_in_all_scopes(self, code, msg):
        """Assert *code* raises SyntaxError matching *msg* in three contexts.

        The snippet is compiled at module scope, executed with a separate
        locals mapping (labelled "class scope" in this suite), and wrapped
        in a ``lambda`` so that it is compiled inside a function body.
        """
        with self.assertRaisesRegex(SyntaxError, msg):
            exec(code, {})  # Module scope
        with self.assertRaisesRegex(SyntaxError, msg):
            exec(code, {}, {})  # Class scope
        with self.assertRaisesRegex(SyntaxError, msg):
            exec(f"lambda: {code}", {})  # Function scope

    def test_named_expression_invalid_01(self):
        code = """x := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_02(self):
        code = """x = y := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_03(self):
        code = """y := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_04(self):
        code = """y0 = y1 := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_06(self):
        code = """((a, b) := (1, 2))"""

        with self.assertRaisesRegex(SyntaxError, "cannot use assignment expressions with tuple"):
            exec(code, {}, {})

    def test_named_expression_invalid_07(self):
        code = """def spam(a = b := 42): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_08(self):
        code = """def spam(a: b := 42 = 5): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_09(self):
        code = """spam(a=b := 'c')"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_10(self):
        code = """spam(x = y := f(x))"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_11(self):
        code = """spam(a=1, b := 2)"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_12(self):
        code = """spam(a=1, (b := 2))"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_13(self):
        # Previously an exact duplicate of test 12. A parenthesized
        # assignment expression is still a positional argument, so it is
        # rejected even when further keyword arguments follow it.
        code = """spam(a=1, (b := 2), c=3)"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_14(self):
        code = """(x := lambda: y := 1)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_15(self):
        code = """(lambda: x := 1)"""

        with self.assertRaisesRegex(SyntaxError,
            "cannot use assignment expressions with lambda"):
            exec(code, {}, {})

    def test_named_expression_invalid_16(self):
        code = "[i + 1 for i in i := [1,2]]"

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_17(self):
        code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_in_class_body(self):
        code = """class Foo():
            [(42, 1 + ((( j := i )))) for i in range(5)]
        """

        with self.assertRaisesRegex(SyntaxError,
            "assignment expression within a comprehension cannot be used in a class body"):
            exec(code, {}, {})

    def test_named_expression_invalid_rebinding_list_comprehension_iteration_variable(self):
        cases = [
            ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
            ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
            ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
            ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
            ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
            ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
            ("Unreachable nested reuse", 'i',
                "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
        ]
        for case, target, code in cases:
            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {})

    def test_named_expression_invalid_rebinding_list_comprehension_inner_loop(self):
        cases = [
            ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
            ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
        ]
        for case, target, code in cases:
            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
            with self.subTest(case=case):
                self._assert_syntax_error_in_all_scopes(code, msg)

    def test_named_expression_invalid_list_comprehension_iterable_expression(self):
        cases = [
            ("Top level", "[i for i in (i := range(5))]"),
            ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
            ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
            ("Different name", "[i for i in (j := range(5))]"),
            ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
            ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
            ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
            ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
            ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
        ]
        msg = "assignment expression cannot be used in a comprehension iterable expression"
        for case, code in cases:
            with self.subTest(case=case):
                self._assert_syntax_error_in_all_scopes(code, msg)

    def test_named_expression_invalid_rebinding_set_comprehension_iteration_variable(self):
        cases = [
            ("Local reuse", 'i', "{i := 0 for i in range(5)}"),
            ("Nested reuse", 'j', "{{(j := 0) for i in range(5)} for j in range(5)}"),
            ("Reuse inner loop target", 'j', "{(j := 0) for i in range(5) for j in range(5)}"),
            ("Unpacking reuse", 'i', "{i := 0 for i, j in {(0, 1)}}"),
            ("Reuse in loop condition", 'i', "{i+1 for i in range(5) if (i := 0)}"),
            ("Unreachable reuse", 'i', "{False or (i:=0) for i in range(5)}"),
            ("Unreachable nested reuse", 'i',
                "{(i, j) for i in range(5) for j in range(5) if True or (i:=10)}"),
        ]
        for case, target, code in cases:
            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {})

    def test_named_expression_invalid_rebinding_set_comprehension_inner_loop(self):
        cases = [
            ("Inner reuse", 'j', "{i for i in range(5) if (j := 0) for j in range(5)}"),
            ("Inner unpacking reuse", 'j', "{i for i in range(5) if (j := 0) for j, k in {(0, 1)}}"),
        ]
        for case, target, code in cases:
            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
            with self.subTest(case=case):
                self._assert_syntax_error_in_all_scopes(code, msg)

    def test_named_expression_invalid_set_comprehension_iterable_expression(self):
        cases = [
            ("Top level", "{i for i in (i := range(5))}"),
            ("Inside tuple", "{i for i in (2, 3, i := range(5))}"),
            # The iterable here is a set display, not a list.
            ("Inside set", "{i for i in {2, 3, i := range(5)}}"),
            ("Different name", "{i for i in (j := range(5))}"),
            ("Lambda expression", "{i for i in (lambda:(j := range(5)))()}"),
            ("Inner loop", "{i for i in range(5) for j in (i := range(5))}"),
            ("Nested comprehension", "{i for i in {j for j in (k := range(5))}}"),
            ("Nested comprehension condition", "{i for i in {j for j in range(5) if (j := True)}}"),
            ("Nested comprehension body", "{i for i in {(j := True) for j in range(5)}}"),
        ]
        msg = "assignment expression cannot be used in a comprehension iterable expression"
        for case, code in cases:
            with self.subTest(case=case):
                self._assert_syntax_error_in_all_scopes(code, msg)
223
224
class NamedExpressionAssignmentTest(unittest.TestCase):
    """Check that valid ``:=`` expressions bind the expected values.

    Covers standalone parenthesized assignments, chained walruses, use in
    ``if``/``while`` conditions, function-call arguments and comprehensions.
    """

    def test_named_expression_assignment_01(self):
        (a := 10)

        self.assertEqual(a, 10)

    def test_named_expression_assignment_02(self):
        a = 20
        (a := a)

        self.assertEqual(a, 20)

    def test_named_expression_assignment_03(self):
        (total := 1 + 2)

        self.assertEqual(total, 3)

    def test_named_expression_assignment_04(self):
        (info := (1, 2, 3))

        self.assertEqual(info, (1, 2, 3))

    def test_named_expression_assignment_05(self):
        # ``:=`` binds tighter than the comma: only ``1`` is assigned to x.
        (x := 1, 2)

        self.assertEqual(x, 1)

    def test_named_expression_assignment_06(self):
        (z := (y := (x := 0)))

        self.assertEqual(x, 0)
        self.assertEqual(y, 0)
        self.assertEqual(z, 0)

    def test_named_expression_assignment_07(self):
        (loc := (1, 2))

        self.assertEqual(loc, (1, 2))

    def test_named_expression_assignment_08(self):
        if spam := "eggs":
            self.assertEqual(spam, "eggs")
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_09(self):
        if True and (spam := True):
            self.assertTrue(spam)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_10(self):
        if (match := 10) == 10:
            # Previously only the comparison was exercised; also verify
            # that the target was actually bound.
            self.assertEqual(match, 10)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_11(self):
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])

    def test_named_expression_assignment_12(self):
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])

    def test_named_expression_assignment_13(self):
        length = len(lines := [1, 2])

        self.assertEqual(length, 2)
        self.assertEqual(lines, [1, 2])

    def test_named_expression_assignment_14(self):
        """
        Where all variables are positive integers, and a is at least as large
        as the n'th root of x, this algorithm returns the floor of the n'th
        root of x (roughly doubling the number of accurate bits per
        iteration).
        """
        a = 9
        n = 2
        x = 3

        while a > (d := x // a**(n-1)):
            a = ((n-1)*a + d) // n

        # floor(sqrt(3)) == 1
        self.assertEqual(a, 1)

    def test_named_expression_assignment_15(self):
        while a := False:
            pass  # This will not run

        self.assertEqual(a, False)

    def test_named_expression_assignment_16(self):
        # Each key/value pair advances a Fibonacci-style (a, b) state through
        # walruses that rebind names of the enclosing function scope.
        a, b = 1, 2
        fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
        self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
327
328
class NamedExpressionScopeTest(unittest.TestCase):
    """Check which scope a ``:=`` target binds in.

    Covers function bodies, comprehensions and generator expressions (whose
    targets bind in the containing scope), call arguments, and interaction
    with ``global`` / ``nonlocal`` declarations.
    """

    def test_named_expression_scope_01(self):
        code = """def spam():
    (a := 5)
print(a)"""

        with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
            exec(code, {}, {})

    def test_named_expression_scope_02(self):
        total = 0
        partial_sums = [total := total + v for v in range(5)]

        self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
        self.assertEqual(total, 10)

    def test_named_expression_scope_03(self):
        containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])

        self.assertTrue(containsOne)
        # any() short-circuits on the first match, so lastNum stops at 1.
        self.assertEqual(lastNum, 1)

    def test_named_expression_scope_04(self):
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])
        self.assertEqual(y, 4)

    def test_named_expression_scope_05(self):
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
        self.assertEqual(y, 3)

    def test_named_expression_scope_06(self):
        res = [[spam := i for i in range(3)] for j in range(2)]

        self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
        self.assertEqual(spam, 2)

    def test_named_expression_scope_07(self):
        len(lines := [1, 2])

        self.assertEqual(lines, [1, 2])

    def test_named_expression_scope_08(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(b := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)
        self.assertEqual(b, 1)

    def test_named_expression_scope_09(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(a := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)

    def test_named_expression_scope_10(self):
        res = [b := [a := 1 for i in range(2)] for j in range(2)]

        self.assertEqual(res, [[1, 1], [1, 1]])
        self.assertEqual(a, 1)
        self.assertEqual(b, [1, 1])

    def test_named_expression_scope_11(self):
        res = [j := i for i in range(5)]

        self.assertEqual(res, [0, 1, 2, 3, 4])
        self.assertEqual(j, 4)

    def test_named_expression_scope_17(self):
        b = 0
        res = [b := i + b for i in range(5)]

        self.assertEqual(res, [0, 1, 3, 6, 10])
        self.assertEqual(b, 10)

    def test_named_expression_scope_18(self):
        def spam(a):
            return a

        res = spam(b := 2)

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_19(self):
        def spam(a):
            return a

        res = spam((b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_20(self):
        def spam(a):
            return a

        res = spam(a=(b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_21(self):
        def spam(a, b):
            return a + b

        res = spam(c := 2, b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_22(self):
        def spam(a, b):
            return a + b

        res = spam((c := 2), b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_23(self):
        def spam(a, b):
            return a + b

        res = spam(b=(c := 2), a=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_24(self):
        a = 10
        def spam():
            nonlocal a
            (a := 20)
        spam()

        self.assertEqual(a, 20)

    def test_named_expression_scope_25(self):
        ns = {}
        code = """a = 10
def spam():
    global a
    (a := 20)
spam()"""

        # ``a = 10`` lands in the separate locals dict; the ``global``
        # walrus inside spam() writes to ``ns``.
        exec(code, ns, {})

        self.assertEqual(ns["a"], 20)

    def test_named_expression_variable_reuse_in_comprehensions(self):
        # The compiler is expected to raise syntax error for comprehension
        # iteration variables, but should be fine with rebinding of other
        # names (e.g. globals, nonlocals, other assignment expressions)

        # The cases are all defined to produce the same expected result
        # Each comprehension is checked at both function scope and module scope
        rebinding = "[x := i for i in range(3) if (x := i) or not x]"
        filter_ref = "[x := i for i in range(3) if x or not x]"
        body_ref = "[x for i in range(3) if (x := i) or not x]"
        nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
        cases = [
            ("Rebind global", f"x = 1; result = {rebinding}"),
            ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
            ("Filter global", f"x = 1; result = {filter_ref}"),
            ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
            ("Body global", f"x = 1; result = {body_ref}"),
            ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
            ("Nested global", f"x = 1; result = {nested_ref}"),
            ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
        ]
        for case, code in cases:
            with self.subTest(case=case):
                ns = {}
                exec(code, ns)
                self.assertEqual(ns["x"], 2)
                self.assertEqual(ns["result"], [0, 1, 2])

    def test_named_expression_global_scope(self):
        sentinel = object()
        global GLOBAL_VAR
        def f():
            global GLOBAL_VAR
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        try:
            f()
            self.assertEqual(GLOBAL_VAR, sentinel)
        finally:
            # Restore the module-level value other tests expect.
            GLOBAL_VAR = None

    def test_named_expression_global_scope_no_global_keyword(self):
        sentinel = object()
        def f():
            GLOBAL_VAR = None
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        f()
        # Without ``global`` the walrus only rebinds f()'s local.
        self.assertEqual(GLOBAL_VAR, None)

    def test_named_expression_nonlocal_scope(self):
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                nonlocal nonlocal_var
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, sentinel)
        f()

    def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            # Without ``nonlocal`` the walrus binds a new local of g().
            self.assertEqual(nonlocal_var, None)
        f()

    def test_named_expression_scope_in_genexp(self):
        a = 1
        b = [1, 2, 3, 4]
        genexp = (c := i + a for i in b)

        # The generator has not run yet, so ``c`` is still unbound.
        self.assertNotIn("c", locals())
        for idx, elem in enumerate(genexp):
            self.assertEqual(elem, b[idx] + a)
        # The genexp's walrus target binds in this (containing) function
        # scope, so after exhaustion it holds the last produced value.
        self.assertEqual(c, b[-1] + a)
578
579
# Run the test cases defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
582