1#!/usr/bin/python
2
3import unittest, tempfile, os
4import common
5import MySQLdb
6from autotest_lib.client.common_lib import global_config
7from autotest_lib.database import database_connection, migrate
8
# Global-config section that names the real database.  The tests never touch
# that database directly; they work against its test counterpart
# (test_<db name>).
# NOTE(review): not referenced by any code visible in this file -- confirm it
# is still needed.
CONFIG_DB = "AUTOTEST_WEB"

# Number of dummy migrations (versions 1 through N) the fake manager serves.
NUM_MIGRATIONS = 3
14
class DummyMigration(object):
    """\
    Dummy migration class that records all migrations done in a class
    variable.

    Each instance stands for one schema version.  Migrating it up or down
    appends a (version, direction) tuple to the shared migrations_done list,
    so tests can check exactly which steps ran and in what order.
    """

    # Class-level record of (version, 'up' | 'down') tuples shared by every
    # instance; reset it with clear_migrations_done().
    migrations_done = []

    def __init__(self, version):
        self.version = version
        # Zero-padded name, e.g. version 1 -> '001_test'.
        self.name = '%03d_test' % version


    @classmethod
    def get_migrations_done(cls):
        """Return the recorded (version, direction) list itself (not a copy)."""
        return cls.migrations_done


    @classmethod
    def clear_migrations_done(cls):
        """Start a fresh, empty record of performed migrations."""
        cls.migrations_done = []


    @classmethod
    def do_migration(cls, version, direction):
        """Record that *version* was migrated in *direction* ('up'/'down')."""
        cls.migrations_done.append((version, direction))


    def migrate_up(self, manager):
        """Record an 'up' step; version 1 also asks *manager* for its table."""
        self.do_migration(self.version, 'up')
        # Version 1 is the first migration, so it calls
        # manager.create_migrate_table() -- presumably the table the manager
        # uses to track the current schema version.
        if self.version == 1:
            manager.create_migrate_table()


    def migrate_down(self, manager):
        """Record a 'down' step; *manager* is not used."""
        self.do_migration(self.version, 'down')
51
52
# The fake migration history: one DummyMigration per version, from 1 up to
# NUM_MIGRATIONS, in ascending order.
MIGRATIONS = [DummyMigration(version)
              for version in xrange(1, NUM_MIGRATIONS + 1)]
54
55
class TestableMigrationManager(migrate.MigrationManager):
    """MigrationManager that serves the in-memory MIGRATIONS list.

    _set_migrations_dir is a no-op and get_migrations returns DummyMigration
    objects, presumably so the inherited sync logic can run without an
    on-disk migrations directory.
    """

    def _set_migrations_dir(self, migrations_dir=None):
        # Deliberate no-op: migrations come from MIGRATIONS, not a directory.
        pass


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return the dummy migrations whose versions lie in the given range.

        Args:
            minimum_version: lowest version to include (inclusive), or None
                for no lower bound.
            maximum_version: highest version to include (inclusive), or None
                for no upper bound.

        Returns:
            A new list of DummyMigration objects in ascending version order.
            Callers can mutate it without disturbing MIGRATIONS.
        """
        migrations = list(MIGRATIONS)
        # Compare against None rather than using `x or default`: an explicit
        # bound of 0 is meaningful (maximum_version=0 selects no migrations).
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations
65
66
class MigrateManagerTest(unittest.TestCase):
    """Drive TestableMigrationManager against the test database.

    The manager serves the MIGRATIONS list, and DummyMigration records every
    up/down step.  Each test therefore checks both the resulting DB version
    and the exact sequence of migration steps.
    """

    def setUp(self):
        self._database = (
            database_connection.DatabaseConnection.get_test_database())
        self._database.connect()
        self.manager = TestableMigrationManager(self._database)
        # The step recorder is class-level state shared by all tests; reset it
        # so each test only sees the steps it triggered.
        DummyMigration.clear_migrations_done()


    def tearDown(self):
        # NOTE(review): only the connection is closed; the schema version is
        # not reset.  test_null_sync leaves the DB at NUM_MIGRATIONS, while
        # test_sync expects to start from version 0.  The tests therefore rely
        # on get_test_database() giving each test a fresh database -- confirm.
        self._database.disconnect()


    def test_sync(self):
        """Sync to latest runs every migration up; sync to 0 undoes them all."""
        all_versions = range(1, NUM_MIGRATIONS + 1)

        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'up') for version in all_versions])

        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        # Downgrades must run newest-first.
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'down')
                          for version in reversed(all_versions)])


    def test_sync_one_by_one(self):
        """Step up one version at a time, then back down to 0 the same way."""
        for version in xrange(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            # The latest step must be this version's 'up' migration.
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version, 'up'))

        for version in xrange(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            # Going down to `version` undoes migration `version + 1`.
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version + 1, 'down'))


    def test_null_sync(self):
        """Syncing an already up-to-date database runs no migrations."""
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])
116
117
class DummyMigrationManager(object):
    """Stand-in manager that records scripts instead of executing them."""

    def __init__(self):
        # Every script handed to execute_script, in call order.
        self.calls = list()


    def execute_script(self, script):
        """Remember *script*; nothing is actually run."""
        self.calls += [script]
125
126
class MigrationTest(unittest.TestCase):
    """Check migrate.Migration with both supported module shapes.

    A migration module here either defines migrate_up/migrate_down callables
    or UP_SQL/DOWN_SQL strings.  In both cases, migrating up then down must
    hand the manager the script 'foo' followed by the script 'bar'.
    """

    def setUp(self):
        self.manager = DummyMigrationManager()


    def _do_migration(self, migration_module):
        """Run *migration_module* up then down; expect ['foo', 'bar']."""
        migration = migrate.Migration('name', 1, migration_module)
        migration.migrate_up(self.manager)
        migration.migrate_down(self.manager)

        self.assertEqual(self.manager.calls, ['foo', 'bar'])


    def test_migration_with_methods(self):
        # Named distinctly so it does not shadow the module-level
        # DummyMigration recorder class.
        class MethodMigration(object):
            @staticmethod
            def migrate_up(manager):
                manager.execute_script('foo')


            @staticmethod
            def migrate_down(manager):
                manager.execute_script('bar')

        self._do_migration(MethodMigration)


    def test_migration_with_strings(self):
        # A module that only provides SQL strings for each direction.
        class SqlStringMigration(object):
            UP_SQL = 'foo'
            DOWN_SQL = 'bar'

        self._do_migration(SqlStringMigration)
160
161
if __name__ == '__main__':
    # Discover and run every TestCase defined in this script.
    unittest.main(module='__main__')
164