1# Test iterators.
2
3import unittest
4from test.test_support import run_unittest, TESTFN, unlink, have_unicode, \
5                              check_py3k_warnings, cpython_only, \
6                              check_free_after_iterating
7
# Expected result of a triple nested loop over range(3): all 27 (i, j, k)
# tuples in lexicographic order.  Written out once as a literal here because
# it is too big to inline in each test that compares against it.
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
            (0, 1, 0), (0, 1, 1), (0, 1, 2),
            (0, 2, 0), (0, 2, 1), (0, 2, 2),

            (1, 0, 0), (1, 0, 1), (1, 0, 2),
            (1, 1, 0), (1, 1, 1), (1, 1, 2),
            (1, 2, 0), (1, 2, 1), (1, 2, 2),

            (2, 0, 0), (2, 0, 1), (2, 0, 2),
            (2, 1, 0), (2, 1, 1), (2, 1, 2),
            (2, 2, 0), (2, 2, 1), (2, 2, 2)]
20
21# Helper classes
22
class BasicIterClass:
    """Counter exposing only next(): returns 0, 1, ..., n-1, then stops.

    The class deliberately has no __iter__ method.
    """

    def __init__(self, n):
        self.i = 0
        self.n = n

    def next(self):
        """Return the current count and advance, or raise StopIteration."""
        current = self.i
        if current < self.n:
            self.i = current + 1
            return current
        raise StopIteration
33
class IteratingSequenceClass:
    """Iterable whose __iter__ returns a new BasicIterClass(n) on every call."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        limit = self.n
        return BasicIterClass(limit)
39
class SequenceClass:
    """Object supporting only __getitem__: index i maps to i for 0 <= i < n.

    There is no __iter__ and no __len__, so iteration has to go through the
    old sequence protocol (successive indexes until IndexError).
    """

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        """Return i when it is a valid index, otherwise raise IndexError."""
        if not (0 <= i < self.n):
            raise IndexError
        return i
48
49# Main test suite
50
51class TestCase(unittest.TestCase):
52
53    # Helper to check that an iterator returns a given sequence
54    def check_iterator(self, it, seq):
55        res = []
56        while 1:
57            try:
58                val = it.next()
59            except StopIteration:
60                break
61            res.append(val)
62        self.assertEqual(res, seq)
63
64    # Helper to check that a for loop generates a given sequence
65    def check_for_loop(self, expr, seq):
66        res = []
67        for val in expr:
68            res.append(val)
69        self.assertEqual(res, seq)
70
71    # Test basic use of iter() function
    def test_iter_basic(self):
        """Drain iter(range(10)) through next() and expect 0..9 in order."""
        self.check_iterator(iter(range(10)), range(10))
74
75    # Test that iter(iter(x)) is the same as iter(x)
76    def test_iter_idempotency(self):
77        seq = range(10)
78        it = iter(seq)
79        it2 = iter(it)
80        self.assertTrue(it is it2)
81
82    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        """A for loop over an explicit list iterator yields 0..9 in order."""
        self.check_for_loop(iter(range(10)), range(10))
85
86    # Test several independent iterators over the same list
    def test_iter_independence(self):
        """Separate iterators over one list advance independently."""
        seq = range(3)
        res = []
        # Each nesting level asks for its own iter(seq).  If the iterators
        # shared position state, the result would not be the full TRIPLETS.
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)
95
96    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        """A triple list comprehension over explicit iterators gives TRIPLETS."""
        seq = range(3)
        # Same shape as test_iter_independence, but as a comprehension.
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)
102
103    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        """A triple list comprehension directly over the list gives TRIPLETS."""
        seq = range(3)
        # No explicit iter() calls here: the comprehension obtains them itself.
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)
108
109    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        """A for loop over an object defining __iter__ yields 0..9."""
        self.check_for_loop(IteratingSequenceClass(10), range(10))
112
113    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        """iter() on an object defining __iter__ can be drained via next()."""
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))
116
117    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        """A for loop over a __getitem__-only object yields 0..9."""
        self.check_for_loop(SequenceClass(10), range(10))
120
121    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        """iter() on a __getitem__-only object can be drained via next()."""
        self.check_iterator(iter(SequenceClass(10)), range(10))
124
    def test_mutating_seq_class_exhausted_iter(self):
        """Growing a sequence revives a pending iterator but not an exhausted one."""
        a = SequenceClass(5)
        exhit = iter(a)
        empit = iter(a)
        # exhit runs until the loop ends, i.e. it has seen the end of the
        # sequence.  empit is advanced the same five times with next() but
        # is never asked for a sixth item, so it has not seen the end.
        for x in exhit:  # exhaust the iterator
            next(empit)  # not exhausted
        # Make the sequence longer after the fact.
        a.n = 7
        self.assertEqual(list(exhit), [])
        self.assertEqual(list(empit), [5, 6])
        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
135
136    # Test a new_style class with __iter__ but no next() method
    def test_new_style_iter_class(self):
        """iter() raises TypeError when __iter__ returns an object without next()."""
        class IterClass(object):
            # Returns itself, but the class defines no next() method.
            def __iter__(self):
                return self
        self.assertRaises(TypeError, iter, IterClass())
142
143    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        """iter(callable, sentinel) works with an instance defining __call__."""
        class C:
            def __init__(self):
                self.i = 0
            # Each call returns the next integer, starting from 0.
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        # Iteration ends when the call returns the sentinel 10, which is
        # itself not yielded.
        self.check_iterator(iter(C(), 10), range(10))
155
156    # Test two-argument iter() with function
    def test_iter_function(self):
        """iter(function, sentinel) yields return values until the sentinel."""
        # The mutable default argument is deliberate: it is the counter that
        # persists between calls of spam().
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))
163
164    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        """StopIteration raised by the function ends iter(function, sentinel)."""
        # The mutable default argument is the counter kept between calls.
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        # The sentinel 20 is never returned; the iteration stops at 10
        # because spam() raises StopIteration itself.
        self.check_iterator(iter(spam, 20), range(10))
173
174    # Test exception propagation through function iterator
    def test_exception_function(self):
        """An exception raised by the function propagates out of the for loop."""
        # The mutable default argument is the counter kept between calls.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            # Everything produced before the failure must have been delivered.
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")
190
191    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        """An exception raised by __getitem__ propagates out of the for loop."""
        class MySequenceClass(SequenceClass):
            # Fails at index 10, otherwise behaves like SequenceClass.
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            # Items 0..9 were delivered before the failure at index 10.
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")
206
207    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        """StopIteration raised by __getitem__ ends the for loop quietly."""
        class MySequenceClass(SequenceClass):
            # Raises StopIteration (not IndexError) at index 10.
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        # Although the sequence claims 20 items, only 0..9 are produced.
        self.check_for_loop(MySequenceClass(20), range(10))
215
216    # Test a big range
    def test_iter_big_range(self):
        """A for loop over an iterator of a 10000-item list yields every item."""
        self.check_for_loop(iter(range(10000)), range(10000))
219
220    # Test an empty list
    def test_iter_empty(self):
        """A for loop over an iterator of an empty list yields nothing."""
        self.check_for_loop(iter([]), [])
223
224    # Test a tuple
    def test_iter_tuple(self):
        """A for loop over a tuple iterator yields the tuple's items in order."""
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))
227
228    # Test an xrange
    def test_iter_xrange(self):
        """A for loop over an xrange iterator yields 0..9."""
        self.check_for_loop(iter(xrange(10)), range(10))
231
232    # Test a string
    def test_iter_string(self):
        """A for loop over a string iterator yields one-character strings."""
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
235
236    # Test a Unicode string
237    if have_unicode:
238        def test_iter_unicode(self):
239            self.check_for_loop(iter(unicode("abcde")),
240                                [unicode("a"), unicode("b"), unicode("c"),
241                                 unicode("d"), unicode("e")])
242
    # Test a dictionary
244    def test_iter_dict(self):
245        dict = {}
246        for i in range(10):
247            dict[i] = None
248        self.check_for_loop(dict, dict.keys())
249
250    # Test a file
251    def test_iter_file(self):
252        f = open(TESTFN, "w")
253        try:
254            for i in range(5):
255                f.write("%d\n" % i)
256        finally:
257            f.close()
258        f = open(TESTFN, "r")
259        try:
260            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
261            self.check_for_loop(f, [])
262        finally:
263            f.close()
264            try:
265                unlink(TESTFN)
266            except OSError:
267                pass
268
269    # Test list()'s use of iterators.
270    def test_builtin_list(self):
271        self.assertEqual(list(SequenceClass(5)), range(5))
272        self.assertEqual(list(SequenceClass(0)), [])
273        self.assertEqual(list(()), [])
274        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))
275
276        d = {"one": 1, "two": 2, "three": 3}
277        self.assertEqual(list(d), d.keys())
278
279        self.assertRaises(TypeError, list, list)
280        self.assertRaises(TypeError, list, 42)
281
282        f = open(TESTFN, "w")
283        try:
284            for i in range(5):
285                f.write("%d\n" % i)
286        finally:
287            f.close()
288        f = open(TESTFN, "r")
289        try:
290            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
291            f.seek(0, 0)
292            self.assertEqual(list(f),
293                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
294        finally:
295            f.close()
296            try:
297                unlink(TESTFN)
298            except OSError:
299                pass
300
    # Test tuple()'s use of iterators.
302    def test_builtin_tuple(self):
303        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
304        self.assertEqual(tuple(SequenceClass(0)), ())
305        self.assertEqual(tuple([]), ())
306        self.assertEqual(tuple(()), ())
307        self.assertEqual(tuple("abc"), ("a", "b", "c"))
308
309        d = {"one": 1, "two": 2, "three": 3}
310        self.assertEqual(tuple(d), tuple(d.keys()))
311
312        self.assertRaises(TypeError, tuple, list)
313        self.assertRaises(TypeError, tuple, 42)
314
315        f = open(TESTFN, "w")
316        try:
317            for i in range(5):
318                f.write("%d\n" % i)
319        finally:
320            f.close()
321        f = open(TESTFN, "r")
322        try:
323            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
324            f.seek(0, 0)
325            self.assertEqual(tuple(f),
326                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
327        finally:
328            f.close()
329            try:
330                unlink(TESTFN)
331            except OSError:
332                pass
333
334    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        """filter() consumes any iterable and keeps its result-type rules."""
        # With None as the function, falsy items (here 0) are dropped.
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        # Tuple and string inputs give back a tuple and a string.
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        # Neither a type object nor an int can be iterated.
        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        # Object whose truth value is whatever it was constructed with.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        bTrue = Boolean(1)
        bFalse = Boolean(0)

        # Iterable that hands out a separate iterator object over its values.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        # 50 alternating items; filtering on "not x" keeps the 25 falsy ones,
        # both for the iterable and for an iterator obtained from it.
        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [bFalse]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [bFalse]*25)
377
378    # Test max() and min()'s use of iterators.
379    def test_builtin_max_min(self):
380        self.assertEqual(max(SequenceClass(5)), 4)
381        self.assertEqual(min(SequenceClass(5)), 0)
382        self.assertEqual(max(8, -1), 8)
383        self.assertEqual(min(8, -1), -1)
384
385        d = {"one": 1, "two": 2, "three": 3}
386        self.assertEqual(max(d), "two")
387        self.assertEqual(min(d), "one")
388        self.assertEqual(max(d.itervalues()), 3)
389        self.assertEqual(min(iter(d.itervalues())), 1)
390
391        f = open(TESTFN, "w")
392        try:
393            f.write("medium line\n")
394            f.write("xtra large line\n")
395            f.write("itty-bitty line\n")
396        finally:
397            f.close()
398        f = open(TESTFN, "r")
399        try:
400            self.assertEqual(min(f), "itty-bitty line\n")
401            f.seek(0, 0)
402            self.assertEqual(max(f), "xtra large line\n")
403        finally:
404            f.close()
405            try:
406                unlink(TESTFN)
407            except OSError:
408                pass
409
410    # Test map()'s use of iterators.
411    def test_builtin_map(self):
412        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))
413
414        d = {"one": 1, "two": 2, "three": 3}
415        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
416        dkeys = d.keys()
417        expected = [(i < len(d) and dkeys[i] or None,
418                     i,
419                     i < len(d) and dkeys[i] or None)
420                    for i in range(5)]
421
422        # Deprecated map(None, ...)
423        with check_py3k_warnings():
424            self.assertEqual(map(None, SequenceClass(5)), range(5))
425            self.assertEqual(map(None, d), d.keys())
426            self.assertEqual(map(None, d,
427                                       SequenceClass(5),
428                                       iter(d.iterkeys())),
429                             expected)
430
431        f = open(TESTFN, "w")
432        try:
433            for i in range(10):
434                f.write("xy" * i + "\n") # line i has len 2*i+1
435        finally:
436            f.close()
437        f = open(TESTFN, "r")
438        try:
439            self.assertEqual(map(len, f), range(1, 21, 2))
440        finally:
441            f.close()
442            try:
443                unlink(TESTFN)
444            except OSError:
445                pass
446
447    # Test zip()'s use of iterators.
448    def test_builtin_zip(self):
449        self.assertEqual(zip(), [])
450        self.assertEqual(zip(*[]), [])
451        self.assertEqual(zip(*[(1, 2), 'ab']), [(1, 'a'), (2, 'b')])
452
453        self.assertRaises(TypeError, zip, None)
454        self.assertRaises(TypeError, zip, range(10), 42)
455        self.assertRaises(TypeError, zip, range(10), zip)
456
457        self.assertEqual(zip(IteratingSequenceClass(3)),
458                         [(0,), (1,), (2,)])
459        self.assertEqual(zip(SequenceClass(3)),
460                         [(0,), (1,), (2,)])
461
462        d = {"one": 1, "two": 2, "three": 3}
463        self.assertEqual(d.items(), zip(d, d.itervalues()))
464
465        # Generate all ints starting at constructor arg.
466        class IntsFrom:
467            def __init__(self, start):
468                self.i = start
469
470            def __iter__(self):
471                return self
472
473            def next(self):
474                i = self.i
475                self.i = i+1
476                return i
477
478        f = open(TESTFN, "w")
479        try:
480            f.write("a\n" "bbb\n" "cc\n")
481        finally:
482            f.close()
483        f = open(TESTFN, "r")
484        try:
485            self.assertEqual(zip(IntsFrom(0), f, IntsFrom(-100)),
486                             [(0, "a\n", -100),
487                              (1, "bbb\n", -99),
488                              (2, "cc\n", -98)])
489        finally:
490            f.close()
491            try:
492                unlink(TESTFN)
493            except OSError:
494                pass
495
496        self.assertEqual(zip(xrange(5)), [(i,) for i in range(5)])
497
498        # Classes that lie about their lengths.
499        class NoGuessLen5:
500            def __getitem__(self, i):
501                if i >= 5:
502                    raise IndexError
503                return i
504
505        class Guess3Len5(NoGuessLen5):
506            def __len__(self):
507                return 3
508
509        class Guess30Len5(NoGuessLen5):
510            def __len__(self):
511                return 30
512
513        self.assertEqual(len(Guess3Len5()), 3)
514        self.assertEqual(len(Guess30Len5()), 30)
515        self.assertEqual(zip(NoGuessLen5()), zip(range(5)))
516        self.assertEqual(zip(Guess3Len5()), zip(range(5)))
517        self.assertEqual(zip(Guess30Len5()), zip(range(5)))
518
519        expected = [(i, i) for i in range(5)]
520        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
521            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
522                self.assertEqual(zip(x, y), expected)
523
    # Test reduce()'s use of iterators.
    def test_deprecated_builtin_reduce(self):
        """Run the reduce() checks inside a check_py3k_warnings() context."""
        with check_py3k_warnings():
            self._test_builtin_reduce()
528
    def _test_builtin_reduce(self):
        """Check reduce() over iterables, with and without an initial value."""
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        # An empty iterable without an initial value is an error ...
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        # ... but with one, the initial value is the result.
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict works on its keys, in keys() order.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
540
541    @unittest.skipUnless(have_unicode, 'needs unicode support')
542    def test_unicode_join_endcase(self):
543
544        # This class inserts a Unicode object into its argument's natural
545        # iteration, in the 3rd position.
546        class OhPhooey:
547            def __init__(self, seq):
548                self.it = iter(seq)
549                self.i = 0
550
551            def __iter__(self):
552                return self
553
554            def next(self):
555                i = self.i
556                self.i = i+1
557                if i == 2:
558                    return unicode("fooled you!")
559                return self.it.next()
560
561        f = open(TESTFN, "w")
562        try:
563            f.write("a\n" + "b\n" + "c\n")
564        finally:
565            f.close()
566
567        f = open(TESTFN, "r")
568        # Nasty:  string.join(s) can't know whether unicode.join() is needed
569        # until it's seen all of s's elements.  But in this case, f's
570        # iterator cannot be restarted.  So what we're testing here is
571        # whether string.join() can manage to remember everything it's seen
572        # and pass that on to unicode.join().
573        try:
574            got = " - ".join(OhPhooey(f))
575            self.assertEqual(got, unicode("a\n - b\n - fooled you! - c\n"))
576        finally:
577            f.close()
578            try:
579                unlink(TESTFN)
580            except OSError:
581                pass
582
583    # Test iterators with 'x in y' and 'x not in y'.
584    def test_in_and_not_in(self):
585        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
586            for i in range(5):
587                self.assertIn(i, sc5)
588            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
589                self.assertNotIn(i, sc5)
590
591        self.assertRaises(TypeError, lambda: 3 in 12)
592        self.assertRaises(TypeError, lambda: 3 not in map)
593
594        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
595        for k in d:
596            self.assertIn(k, d)
597            self.assertNotIn(k, d.itervalues())
598        for v in d.values():
599            self.assertIn(v, d.itervalues())
600            self.assertNotIn(v, d)
601        for k, v in d.iteritems():
602            self.assertIn((k, v), d.iteritems())
603            self.assertNotIn((v, k), d.iteritems())
604
605        f = open(TESTFN, "w")
606        try:
607            f.write("a\n" "b\n" "c\n")
608        finally:
609            f.close()
610        f = open(TESTFN, "r")
611        try:
612            for chunk in "abc":
613                f.seek(0, 0)
614                self.assertNotIn(chunk, f)
615                f.seek(0, 0)
616                self.assertIn((chunk + "\n"), f)
617        finally:
618            f.close()
619            try:
620                unlink(TESTFN)
621            except OSError:
622                pass
623
624    # Test iterators with operator.countOf (PySequence_Count).
625    def test_countOf(self):
626        from operator import countOf
627        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
628        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
629        self.assertEqual(countOf("122325", "2"), 3)
630        self.assertEqual(countOf("122325", "6"), 0)
631
632        self.assertRaises(TypeError, countOf, 42, 1)
633        self.assertRaises(TypeError, countOf, countOf, countOf)
634
635        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
636        for k in d:
637            self.assertEqual(countOf(d, k), 1)
638        self.assertEqual(countOf(d.itervalues(), 3), 3)
639        self.assertEqual(countOf(d.itervalues(), 2j), 1)
640        self.assertEqual(countOf(d.itervalues(), 1j), 0)
641
642        f = open(TESTFN, "w")
643        try:
644            f.write("a\n" "b\n" "c\n" "b\n")
645        finally:
646            f.close()
647        f = open(TESTFN, "r")
648        try:
649            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
650                f.seek(0, 0)
651                self.assertEqual(countOf(f, letter + "\n"), count)
652        finally:
653            f.close()
654            try:
655                unlink(TESTFN)
656            except OSError:
657                pass
658
659    # Test iterators with operator.indexOf (PySequence_Index).
660    def test_indexOf(self):
661        from operator import indexOf
662        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
663        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
664        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
665        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
666        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
667        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
668
669        self.assertEqual(indexOf("122325", "2"), 1)
670        self.assertEqual(indexOf("122325", "5"), 5)
671        self.assertRaises(ValueError, indexOf, "122325", "6")
672
673        self.assertRaises(TypeError, indexOf, 42, 1)
674        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
675
676        f = open(TESTFN, "w")
677        try:
678            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
679        finally:
680            f.close()
681        f = open(TESTFN, "r")
682        try:
683            fiter = iter(f)
684            self.assertEqual(indexOf(fiter, "b\n"), 1)
685            self.assertEqual(indexOf(fiter, "d\n"), 1)
686            self.assertEqual(indexOf(fiter, "e\n"), 0)
687            self.assertRaises(ValueError, indexOf, fiter, "a\n")
688        finally:
689            f.close()
690            try:
691                unlink(TESTFN)
692            except OSError:
693                pass
694
695        iclass = IteratingSequenceClass(3)
696        for i in range(3):
697            self.assertEqual(indexOf(iclass, i), i)
698        self.assertRaises(ValueError, indexOf, iclass, -1)
699
700    # Test iterators with file.writelines().
    def test_writelines(self):
        """file.writelines() accepts any iterable of strings."""
        f = file(TESTFN, "w")

        try:
            # Non-iterable arguments are rejected.
            self.assertRaises(TypeError, f.writelines, None)
            self.assertRaises(TypeError, f.writelines, 42)

            # Lines "1".."5" come from a list, a tuple and a dict (its key);
            # the empty dict writes nothing.
            f.writelines(["1\n", "2\n"])
            f.writelines(("3\n", "4\n"))
            f.writelines({'5\n': None})
            f.writelines({})

            # Try a big chunk too.
            # Iterator producing str(i) + '\n' for start <= i < finish.
            class Iterator:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish
                    self.i = self.start

                def next(self):
                    if self.i >= self.finish:
                        raise StopIteration
                    result = str(self.i) + '\n'
                    self.i += 1
                    return result

                def __iter__(self):
                    return self

            # Iterable that creates a new Iterator on each __iter__ call.
            class Whatever:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish

                def __iter__(self):
                    return Iterator(self.start, self.finish)

            # Lines "6".."2005", continuing after the five written above.
            f.writelines(Whatever(6, 6+2000))
            f.close()

            # Reopen for reading; f is rebound so the finally clause below
            # closes whichever file object is current.
            f = file(TESTFN)
            expected = [str(i) + "\n" for i in range(1, 2006)]
            self.assertEqual(list(f), expected)

        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass
751
752
753    # Test iterators on RHS of unpacking assignments.
754    def test_unpack_iter(self):
755        a, b = 1, 2
756        self.assertEqual((a, b), (1, 2))
757
758        a, b, c = IteratingSequenceClass(3)
759        self.assertEqual((a, b, c), (0, 1, 2))
760
761        try:    # too many values
762            a, b = IteratingSequenceClass(3)
763        except ValueError:
764            pass
765        else:
766            self.fail("should have raised ValueError")
767
768        try:    # not enough values
769            a, b, c = IteratingSequenceClass(2)
770        except ValueError:
771            pass
772        else:
773            self.fail("should have raised ValueError")
774
775        try:    # not iterable
776            a, b, c = len
777        except TypeError:
778            pass
779        else:
780            self.fail("should have raised TypeError")
781
782        a, b, c = {1: 42, 2: 42, 3: 42}.itervalues()
783        self.assertEqual((a, b, c), (42, 42, 42))
784
785        f = open(TESTFN, "w")
786        lines = ("a\n", "bb\n", "ccc\n")
787        try:
788            for line in lines:
789                f.write(line)
790        finally:
791            f.close()
792        f = open(TESTFN, "r")
793        try:
794            a, b, c = f
795            self.assertEqual((a, b, c), lines)
796        finally:
797            f.close()
798            try:
799                unlink(TESTFN)
800            except OSError:
801                pass
802
803        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
804        self.assertEqual((a, b, c), (0, 1, 42))
805
806
    @cpython_only
    def test_ref_counting_behavior(self):
        """A failed unpack from an iterator leaves no stray item references."""
        # C.count tracks live instances: __new__ increments it and __del__
        # decrements it.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        # Sanity check of the counting itself.
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        # Three items into two targets; the ValueError is tolerated here,
        # not asserted.
        try:
            a, b = iter(l)
        except ValueError:
            pass
        # Once the list is gone, every instance must have been finalized.
        del l
        self.assertEqual(C.count, 0)
830
831
832    # Make sure StopIteration is a "sink state".
833    # This tests various things that weren't sink states in Python 2.2.1,
834    # plus various things that always were fine.
835
    def test_sinkstate_list(self):
        """An exhausted list iterator stays exhausted after the list grows."""
        # This used to fail
        a = range(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # Appending after exhaustion must not revive the iterator.
        a.extend(range(5, 10))
        self.assertEqual(list(b), [])
843
    def test_sinkstate_tuple(self):
        """An exhausted tuple iterator yields nothing on a second pass."""
        a = (0, 1, 2, 3, 4)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])
849
    def test_sinkstate_string(self):
        """An exhausted string iterator yields nothing on a second pass."""
        a = "abcde"
        b = iter(a)
        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(list(b), [])
855
    def test_sinkstate_sequence(self):
        """An exhausted sequence iterator stays exhausted after the sequence grows."""
        # This used to fail
        a = SequenceClass(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # Raising the limit after exhaustion must not revive the iterator.
        a.n = 10
        self.assertEqual(list(b), [])
863
864    def test_sinkstate_callable(self):
865        # This used to fail
866        def spam(state=[0]):
867            i = state[0]
868            state[0] = i+1
869            if i == 10:
870                raise AssertionError, "shouldn't have gotten this far"
871            return i
872        b = iter(spam, 5)
873        self.assertEqual(list(b), range(5))
874        self.assertEqual(list(b), [])
875
876    def test_sinkstate_dict(self):
877        # XXX For a more thorough test, see towards the end of:
878        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
879        a = {1:1, 2:2, 0:0, 4:4, 3:3}
880        for b in iter(a), a.iterkeys(), a.iteritems(), a.itervalues():
881            b = iter(a)
882            self.assertEqual(len(list(b)), 5)
883            self.assertEqual(list(b), [])
884
    def test_sinkstate_yield(self):
        """An exhausted generator yields nothing on a second pass."""
        def gen():
            for i in range(5):
                yield i
        b = gen()
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])
892
    def test_sinkstate_range(self):
        """An exhausted xrange iterator yields nothing on a second pass."""
        a = xrange(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])
898
    def test_sinkstate_enumerate(self):
        """An exhausted enumerate object yields nothing on a second pass."""
        a = range(5)
        e = enumerate(a)
        b = iter(e)
        # enumerate pairs each item with its index, here (0, 0) .. (4, 4).
        self.assertEqual(list(b), zip(range(5), range(5)))
        self.assertEqual(list(b), [])
905
    def test_3720(self):
        """Looping over an iterator that deletes its own next() must not crash."""
        # Avoid a crash, when an iterator deletes its next() method.
        class BadIterator(object):
            def __iter__(self):
                return self
            def next(self):
                # Removes the method from the class during the first call.
                del BadIterator.next
                return 1

        # A TypeError from the loop is tolerated but not required; the test
        # only needs the loop to finish without crashing the interpreter.
        try:
            for i in BadIterator() :
                pass
        except TypeError:
            pass
920
    def test_extending_list_with_iterator_does_not_segfault(self):
        """Smoke test: extend a shrunken list from a generator."""
        # The code to extend a list with an iterator has a fair
        # amount of nontrivial logic in terms of guessing how
        # much memory to allocate in advance, "stealing" refs,
        # and then shrinking at the end.  This is a basic smoke
        # test for that scenario.
        def gen():
            for i in range(500):
                yield i
        lst = [0] * 500
        # Shrink the list from the front first: 500 - 240 = 260 items left.
        for i in range(240):
            lst.pop(0)
        # 260 remaining items plus 500 generated ones gives 760.
        lst.extend(gen())
        self.assertEqual(len(lst), 760)
935
    def test_free_after_iterating(self):
        """Delegate to check_free_after_iterating for iter() over SequenceClass."""
        # NOTE(review): presumably the helper checks that the iterable is
        # released once iteration is over; (0,) looks like the constructor
        # arguments for SequenceClass -- confirm against test_support.
        check_free_after_iterating(self, iter, SequenceClass, (0,))
938
939
def test_main():
    """Run every test in TestCase through test_support.run_unittest."""
    run_unittest(TestCase)
942
943
# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    test_main()
946