1import pickle
2import struct
3from cStringIO import StringIO
4
5from test import test_support
6
7from test.pickletester import (AbstractUnpickleTests,
8                               AbstractPickleTests,
9                               AbstractPickleModuleTests,
10                               AbstractPersistentPicklerTests,
11                               AbstractPicklerUnpicklerObjectTests,
12                               BigmemPickleTests)
13
14class PickleTests(AbstractUnpickleTests, AbstractPickleTests,
15                  AbstractPickleModuleTests):
16
17    def dumps(self, arg, proto=0, fast=0):
18        # Ignore fast
19        return pickle.dumps(arg, proto)
20
21    def loads(self, buf):
22        # Ignore fast
23        return pickle.loads(buf)
24
25    module = pickle
26    error = KeyError
27    bad_stack_errors = (IndexError,)
28    bad_mark_errors = (IndexError, pickle.UnpicklingError,
29                       TypeError, AttributeError, EOFError)
30    truncated_errors = (pickle.UnpicklingError, EOFError,
31                        AttributeError, ValueError,
32                        struct.error, IndexError, ImportError,
33                        TypeError, KeyError)
34
35class UnpicklerTests(AbstractUnpickleTests):
36
37    error = KeyError
38    bad_stack_errors = (IndexError,)
39    bad_mark_errors = (IndexError, pickle.UnpicklingError,
40                       TypeError, AttributeError, EOFError)
41    truncated_errors = (pickle.UnpicklingError, EOFError,
42                        AttributeError, ValueError,
43                        struct.error, IndexError, ImportError,
44                        TypeError, KeyError)
45
46    def loads(self, buf):
47        f = StringIO(buf)
48        u = pickle.Unpickler(f)
49        return u.load()
50
51class PicklerTests(AbstractPickleTests):
52
53    def dumps(self, arg, proto=0, fast=0):
54        f = StringIO()
55        p = pickle.Pickler(f, proto)
56        if fast:
57            p.fast = fast
58        p.dump(arg)
59        f.seek(0)
60        return f.read()
61
62    def loads(self, buf):
63        f = StringIO(buf)
64        u = pickle.Unpickler(f)
65        return u.load()
66
67class PersPicklerTests(AbstractPersistentPicklerTests):
68
69    def dumps(self, arg, proto=0, fast=0):
70        class PersPickler(pickle.Pickler):
71            def persistent_id(subself, obj):
72                return self.persistent_id(obj)
73        f = StringIO()
74        p = PersPickler(f, proto)
75        if fast:
76            p.fast = fast
77        p.dump(arg)
78        f.seek(0)
79        return f.read()
80
81    def loads(self, buf):
82        class PersUnpickler(pickle.Unpickler):
83            def persistent_load(subself, obj):
84                return self.persistent_load(obj)
85        f = StringIO(buf)
86        u = PersUnpickler(f)
87        return u.load()
88
89class PicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
90
91    pickler_class = pickle.Pickler
92    unpickler_class = pickle.Unpickler
93
94class PickleBigmemPickleTests(BigmemPickleTests):
95
96    def dumps(self, arg, proto=0, fast=0):
97        # Ignore fast
98        return pickle.dumps(arg, proto)
99
100    def loads(self, buf):
101        # Ignore fast
102        return pickle.loads(buf)
103
104
def test_main():
    """Run every test case class defined in this module, then the doctests
    embedded in the pickle module."""
    case_classes = (
        PickleTests,
        UnpicklerTests,
        PicklerTests,
        PersPicklerTests,
        PicklerUnpicklerObjectTests,
        PickleBigmemPickleTests,
    )
    test_support.run_unittest(*case_classes)
    test_support.run_doctest(pickle)


# Allow running this file directly as a script.
if __name__ == "__main__":
    test_main()
118