1
2import os
3import pickle
4import sys
5
# Optional C-accelerated pickler: present only on Python 2.  It stays None
# on Python 3, or when the Python 2 import fails.
cPickle = None
if sys.version_info[0] < 3:
    try:
        import cPickle
    except ImportError:
        pass
13
14import unittest
15
16from test_all import db, test_support, get_new_environment_path, get_new_database_path
17
18#----------------------------------------------------------------------
19
class pickleTestCase(unittest.TestCase):
    """Verify that DBError can be pickled and unpickled"""
    db_name = 'test-dbobj.db'

    def setUp(self):
        self.homeDir = get_new_environment_path()

    def tearDown(self):
        # The handles have already been closed by _base_test_pickle_DBError
        # (in its finally clauses); drop the references before removing the
        # environment directory.
        if hasattr(self, 'db'):
            del self.db
        if hasattr(self, 'env'):
            del self.env
        test_support.rmtree(self.homeDir)

    def _base_test_pickle_DBError(self, pickle):
        """Provoke a DBError, round-trip it through *pickle* and compare.

        *pickle* is the pickling module under test (``pickle`` or
        ``cPickle``); only its ``dumps`` and ``loads`` functions are used.
        The environment and database handles are closed on every path,
        including failures, so tearDown can remove the directory.
        """
        self.env = db.DBEnv()
        try:
            self.env.open(self.homeDir, db.DB_CREATE | db.DB_INIT_MPOOL)
            self.db = db.DB(self.env)
            try:
                self.db.open(self.db_name, db.DB_HASH, db.DB_CREATE)
                self._check_DBError_roundtrip(pickle)
            finally:
                # Close the database before its environment.
                self.db.close()
        finally:
            self.env.close()

    def _check_DBError_roundtrip(self, pickle):
        """Trigger a DB_NOOVERWRITE DBError and assert it survives pickling."""
        self.db.put('spam', 'eggs')
        self.assertEqual(self.db['spam'], 'eggs')
        try:
            # Overwriting an existing key with DB_NOOVERWRITE must fail.
            self.db.put('spam', 'ham', flags=db.DB_NOOVERWRITE)
        except db.DBError:
            # sys.exc_info() works on both Python 2 and Python 3, unlike
            # the Python 2-only "except X, e" syntax.
            egg = sys.exc_info()[1]
        else:
            self.fail("where's my DBError exception?!?")
        pickledEgg = pickle.dumps(egg)
        rottenEgg = pickle.loads(pickledEgg)
        self.assertEqual(type(rottenEgg), type(egg))
        self.assertEqual(rottenEgg.args, egg.args)

    def test01_pickle_DBError(self):
        self._base_test_pickle_DBError(pickle=pickle)

    # cPickle is None when unavailable, in which case this test is not defined.
    if cPickle:
        def test02_cPickle_DBError(self):
            self._base_test_pickle_DBError(pickle=cPickle)
61
62#----------------------------------------------------------------------
63
def test_suite():
    """Return a TestSuite containing every test method of pickleTestCase."""
    # TestLoader.loadTestsFromTestCase replaces unittest.makeSuite, which is
    # deprecated since Python 3.11 and removed in 3.13; it exists on Python 2.
    return unittest.TestLoader().loadTestsFromTestCase(pickleTestCase)
66
if __name__ == '__main__':
    # Running this file directly executes the suite returned by test_suite().
    suite_factory_name = 'test_suite'
    unittest.main(defaultTest=suite_factory_name)
69