1"""TestCases for distributed transactions.
2"""
3
4import os
5import unittest
6
7from test_all import db, test_support, get_new_environment_path, \
8        get_new_database_path
9
10from test_all import verbose
11
12#----------------------------------------------------------------------
13
class DBTxn_distributed(unittest.TestCase):
    """Tests for prepared (distributed) transactions and ``txn_recover``.

    The test prepares ``num_txns`` transactions, re-creates the
    environment several times, and resolves the recovered transactions
    with a commit / commit / abort / discard rotation.  Class attributes
    select the variant:

    * ``nosync`` -- whether ``DB_TXN_NOSYNC`` is set on the environment;
    * ``must_open_db`` -- whether the queue database is opened in each
      environment re-created mid-test (``setUp`` and the final check
      always open it).
    """

    num_txns = 1234
    nosync = True
    must_open_db = False

    def _create_env(self, must_open_db):
        """Create ``self.dbenv`` in ``self.homeDir`` and a ``self.db`` handle.

        The environment is opened with ``DB_RECOVER`` plus transactions,
        logging, the memory pool and locking.  When *must_open_db* is true,
        ``self.filename`` is opened as a ``DB_QUEUE`` database inside a
        committed transaction; otherwise the handle is left unopened.
        """
        self.dbenv = db.DBEnv()
        self.dbenv.set_tx_max(self.num_txns)
        # Lock-table limits scale with num_txns; presumably so that every
        # prepared transaction can hold its locks at the same time.
        lock_limit = self.num_txns * 2
        self.dbenv.set_lk_max_lockers(lock_limit)
        self.dbenv.set_lk_max_locks(lock_limit)
        self.dbenv.set_lk_max_objects(lock_limit)
        if self.nosync:
            self.dbenv.set_flags(db.DB_TXN_NOSYNC, True)
        self.dbenv.open(self.homeDir,
                db.DB_CREATE | db.DB_THREAD | db.DB_RECOVER |
                db.DB_INIT_TXN | db.DB_INIT_LOG | db.DB_INIT_MPOOL |
                db.DB_INIT_LOCK, 0o666)
        self.db = db.DB(self.dbenv)
        # Record length equals DB_GID_SIZE, the width every stored GID is
        # padded to by the test.
        self.db.set_re_len(db.DB_GID_SIZE)
        if must_open_db:
            txn = self.dbenv.txn_begin()
            self.db.open(self.filename,
                    db.DB_QUEUE, db.DB_CREATE | db.DB_THREAD, 0o666,
                    txn=txn)
            txn.commit()

    def setUp(self):
        self.homeDir = get_new_environment_path()
        self.filename = "test"
        self._create_env(must_open_db=True)

    def _destroy_env(self):
        """Close ``self.db`` and ``self.dbenv``.

        The log is flushed first when ``nosync`` is set, and always on
        Berkeley DB 4.6 (the original code marks this as a known-bug
        work-around).
        """
        if self.nosync or (db.version()[:2] == (4, 6)):  # Known bug
            self.dbenv.log_flush()
        self.db.close()
        self.dbenv.close()

    def tearDown(self):
        self._destroy_env()
        test_support.rmtree(self.homeDir)

    def _recreate_env(self, must_open_db):
        """Close and re-open the environment (and database handle)."""
        self._destroy_env()
        self._create_env(must_open_db)

    def test01_distributed_transactions(self):
        import sys
        if sys.version_info[0] >= 3:
            adapt = lambda x: bytes(x, "ascii")
        else:
            adapt = lambda x: x
        # Decimal GID right-aligned to a width of DB_GID_SIZE; built once
        # instead of on every loop iteration.
        gid_format = "%%%dd" % db.DB_GID_SIZE
        txns = set()

        # Create transactions, "prepare" them, and let them be garbage
        # collected.
        txn = None
        for i in range(self.num_txns):
            txn = self.dbenv.txn_begin()
            gid = adapt(gid_format % i)
            self.db.put(i, gid, txn=txn, flags=db.DB_APPEND)
            txns.add(gid)
            txn.prepare(gid)
        # Drop the last handle reference before the env is closed.  Using
        # assignment instead of ``del`` cannot raise NameError when the
        # loop body never ran.
        txn = None

        self._recreate_env(self.must_open_db)

        # Get "to be recovered" transactions but let them be garbage
        # collected.
        recovered_txns = self.dbenv.txn_recover()
        self.assertEqual(self.num_txns, len(recovered_txns))
        for gid, txn in recovered_txns:
            self.assertIn(gid, txns)
        txn = None
        del recovered_txns

        self._recreate_env(self.must_open_db)

        # Get "to be recovered" transactions.  Commit, abort and discard
        # them in a fixed rotation: commit, commit, abort, discard.
        recovered_txns = self.dbenv.txn_recover()
        self.assertEqual(self.num_txns, len(recovered_txns))
        discard_txns = set()
        committed_txns = set()
        for i, (gid, txn) in enumerate(recovered_txns):
            action = i % 4
            if action in (0, 1):
                committed_txns.add(gid)
                txn.commit()
            elif action == 2:
                txn.abort()
            else:
                txn.discard()
                discard_txns.add(gid)
        txn = None
        del recovered_txns

        self._recreate_env(self.must_open_db)

        # Verify the discarded transactions (and only those) are still
        # around, then dispose of them.
        recovered_txns = self.dbenv.txn_recover()
        self.assertEqual(len(discard_txns), len(recovered_txns))
        self.assertEqual(discard_txns,
                         set(gid for gid, _ in recovered_txns))
        for gid, txn in recovered_txns:
            txn.abort()
        txn = None
        del recovered_txns

        self._recreate_env(must_open_db=True)

        # Be sure there are no pending transactions, and that the number
        # of records matches the number of committed transactions.
        recovered_txns = self.dbenv.txn_recover()
        self.assertEqual(len(recovered_txns), 0)
        self.assertEqual(len(committed_txns), self.db.stat()["nkeys"])
127
class DBTxn_distributedSYNC(DBTxn_distributed):
    """Distributed-transaction test with DB_TXN_NOSYNC left unset."""

    nosync = False
130
class DBTxn_distributed_must_open_db(DBTxn_distributed):
    """Distributed-transaction test that opens the database in every
    re-created environment."""

    must_open_db = True
133
class DBTxn_distributedSYNC_must_open_db(DBTxn_distributed):
    """Both variants combined: DB_TXN_NOSYNC is left unset and the
    database is opened in every re-created environment."""

    must_open_db = True
    nosync = False
137
138#----------------------------------------------------------------------
139
def test_suite():
    """Build the test suite, gated on the Berkeley DB version.

    The basic variants are added for version 4.5 or later; the
    ``must_open_db`` variants only for 4.6 or later.
    """
    # Equivalent to the deprecated unittest.makeSuite (removed in Python
    # 3.13): same 'test' method prefix, sort order and suite class.
    loader = unittest.defaultTestLoader
    version = db.version()
    suite = unittest.TestSuite()
    if version >= (4, 5):
        suite.addTest(loader.loadTestsFromTestCase(DBTxn_distributed))
        suite.addTest(loader.loadTestsFromTestCase(DBTxn_distributedSYNC))
    if version >= (4, 6):
        suite.addTest(loader.loadTestsFromTestCase(
            DBTxn_distributed_must_open_db))
        suite.addTest(loader.loadTestsFromTestCase(
            DBTxn_distributedSYNC_must_open_db))
    return suite
149
150
if __name__ == '__main__':
    # Run directly: hand unittest the name of the version-gated suite.
    default_suite_name = 'test_suite'
    unittest.main(defaultTest=default_suite_name)
153