1#!/usr/bin/python -u
2
3import os, sys, re, tempfile
4from optparse import OptionParser
5import common
6from autotest_lib.client.common_lib import utils
7from autotest_lib.database import database_connection
8
# Table holding the single row that records the current schema version.
MIGRATE_TABLE = 'migrate_info'

# Root of the autotest tree: the parent of the directory holding this file.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')

# Global-config database section -> directory of its NNN_*.py migrations.
_MIGRATIONS_DIRS = dict(
    AUTOTEST_WEB=os.path.join(_AUTODIR, 'frontend', 'migrations'),
    TKO=os.path.join(_AUTODIR, 'tko', 'migrations'),
    AUTOTEST_SERVER_DB=os.path.join(_AUTODIR, 'database',
                                    'server_db_migrations'),
)

# Fallback for unknown sections; relative, so resolved against the CWD.
_DEFAULT_MIGRATIONS_DIR = 'migrations'
19
class Migration(object):
    """One numbered database migration, backed by an importable module.

    The module has to provide an "up" step and a "down" step. Each step is
    either a callable (migrate_up / migrate_down) that receives the
    MigrationManager, or a SQL string (UP_SQL / DOWN_SQL) that is handed to
    the manager's execute_script(). If the callable attribute is present and
    truthy it is used; otherwise the SQL attribute is used.
    """
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        """
        @param name: Module name of the migration (file name minus '.py').
        @param version: Integer schema version this migration moves to.
        @param module: The imported migration module.
        """
        self.name, self.version, self.module = name, version, module
        # Fail fast if either direction is missing from the module.
        for attributes in (self._UP_ATTRIBUTES, self._DOWN_ATTRIBUTES):
            self._check_attributes(attributes)


    @classmethod
    def from_file(cls, filename):
        """Builds a Migration by importing a migration file.

        @param filename: Base name of a migration file such as
            '001_initial.py'. Its first three characters are parsed as the
            version and the trailing '.py' is stripped to get the module
            name. The module is imported by bare name, so its directory must
            already be importable (e.g. on sys.path).

        @return The new Migration object.

        """
        module_name = filename[:-3]
        version = int(filename[:3])
        module = __import__(module_name, globals(), locals(), [])
        return cls(module_name, version, module)


    def _check_attributes(self, attributes):
        """Asserts the module defines at least one of the given attributes.

        @param attributes: A (method_name, sql_name) pair.

        """
        method_name, sql_name = attributes
        assert any(hasattr(self.module, candidate)
                   for candidate in (method_name, sql_name))


    def _execute_migration(self, attributes, manager):
        """Runs one direction of this migration against a manager.

        @param attributes: A (method_name, sql_name) pair.
        @param manager: The MigrationManager to run against.

        """
        method_name, sql_name = attributes
        step = getattr(self.module, method_name, None)
        if not step:
            # No usable callable: fall back to the SQL script attribute.
            script = getattr(self.module, sql_name)
            assert isinstance(script, basestring)
            manager.execute_script(script)
            return
        assert callable(step)
        step(manager)


    def migrate_up(self, manager):
        """Applies this migration, moving to a newer schema version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._UP_ATTRIBUTES, manager)


    def migrate_down(self, manager):
        """Reverts this migration, moving to an older schema version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
82
83
class MigrationManager(object):
    """Manages database migrations.

    Applies the numbered migration modules (NNN_*.py) found in a migrations
    directory to a database, recording the current schema version in the
    MIGRATE_TABLE table.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """
        @param database_connection: An open database connection object,
            e.g. a database_connection.DatabaseConnection.
        @param migrations_dir: Directory containing the migration modules,
            or None to derive it from the connection's config section.
        @param force: If True, never ask the user for confirmation.
        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken. For use with migrations that
        # may make destructive queries.
        self.simulate = False
        self._set_migrations_dir(migrations_dir)


    def _set_migrations_dir(self, migrations_dir=None):
        """Resolves the migrations directory and makes it importable.

        @param migrations_dir: Explicit directory, or None to use the
            directory registered for this connection's config section
            (falling back to ./migrations).

        """
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                _MIGRATIONS_DIRS.get(self._config_section(),
                                     _DEFAULT_MIGRATIONS_DIR))
        # Validate before touching sys.path so a bad directory isn't added.
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
        self.migrations_dir = migrations_dir
        # Migration.from_file() imports modules by bare name, so the
        # directory must be on sys.path; don't add duplicates when several
        # managers are created in one process.
        if migrations_dir not in sys.path:
            sys.path.append(migrations_dir)


    def _config_section(self):
        """Returns the global-config section name of the connection."""
        return self._database.global_config_section


    def get_db_name(self):
        """Gets the database name."""
        return self._database.get_database_info()['db_name']


    def execute(self, query, *parameters):
        """Executes a database query.

        @param query: The query to execute.
        @param parameters: Associated parameters for the query.

        @return The result of the query.

        """
        return self._database.execute(query, parameters)


    def execute_script(self, script):
        """Executes a set of database queries.

        The script is split naively on ';', so semicolons inside string
        literals or comments are not supported. Empty statements are skipped.

        @param script: A string of semicolon-separated queries.

        """
        sql_statements = [statement.strip()
                          for statement in script.split(';')
                          if statement.strip()]
        for statement in sql_statements:
            self.execute(statement)


    def check_migrate_table_exists(self):
        """Checks whether the migration table exists."""
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False
        return True


    def create_migrate_table(self):
        """Creates (or resets) the migration table with a single version 0 row.
        """
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1


    def set_db_version(self, version):
        """Sets the database version.

        @param version: The version to which to set the database.

        """
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1


    def get_db_version(self):
        """Gets the database version.

        @return The database version, or 0 when the migration table is
            missing or empty.

        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Gets the list of migrations to perform.

        Migration files are named NNN_*.py; sorting the names therefore
        orders them by version.

        @param minimum_version: The minimum database version (inclusive).
        @param maximum_version: The maximum database version (inclusive).

        @return A list of Migration objects, in ascending version order.

        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Performs a migration.

        @param migration: The Migration to perform.
        @param migrate_up: Whether to migrate up (if not, then migrates down).

        """
        direction = 'up' if migrate_up else 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Performs a migration to a specified version.

        An uninitialized (version 0) AUTOTEST_WEB database is first
        bootstrapped from schema_129.sql; remaining migrations are then
        applied upwards or downwards as needed.

        @param version: The version to which to migrate the database.

        """
        current_version = self.get_db_version()
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()

        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)


    def _migrate_from_base(self):
        """Initialize the AFE database from the version-129 schema dump.

        schema_129.sql (next to this file) is %-formatted with the DB
        username, so any literal '%' in it must be written as '%%'.
        """
        self.confirm_initialization()

        migration_script = utils.read_file(
                os.path.join(os.path.dirname(__file__), 'schema_129.sql'))
        migration_script = migration_script % (
                dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)

        self.create_migrate_table()
        self.set_db_version(129)


    def confirm_initialization(self):
        """Confirms with the user that we should initialize the database.

        @raises Exception, if the user chooses to abort the migration.

        """
        if not self.force:
            response = raw_input(
                'Your %s database does not appear to be initialized.  Do you '
                'want to recreate it (this will result in loss of any existing '
                'data) (yes/No)? ' % self.get_db_name())
            if response != 'yes':
                raise Exception('User has chosen to abort migration')


    def get_latest_version(self):
        """Gets the latest database version.

        @raises IndexError if the migrations directory has no migrations.
        """
        migrations = self.get_migrations()
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrates the database to the latest version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Creates 'test_<db_name>' and switches the connection to it."""
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)


    def remove_test_db(self):
        """Drops the currently connected database and reconnects.

        Meant to be called while connected to the test DB created by
        initialize_test_db(); connect() is then called without arguments to
        return to the real DB.
        """
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.get_db_name())
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()


    @staticmethod
    def _shell_quote(value):
        """Quotes a string for safe interpolation into a shell command."""
        try:
            from shlex import quote  # Python 3.3+
        except ImportError:
            from pipes import quote  # Python 2
        return quote(value)


    def get_mysql_args(self):
        """Returns the mysql arguments as a string.

        Every value is shell-quoted because the result is interpolated into
        command lines run via os.system(); a password (or other value) with
        spaces or shell metacharacters would otherwise break the command or
        be interpreted by the shell. Plain values are left unchanged.
        """
        info = self._database.get_database_info()
        quoted = dict((key, self._shell_quote('%s' % info[key]))
                      for key in ('username', 'password', 'host', 'db_name'))
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                quoted)


    def migrate_to_version_or_latest(self, version):
        """Migrates to either a specified version, or the latest version.

        @param version: The version to which to migrate the database,
            or None in order to migrate to the latest version.

        """
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrates the database.

        @param version: The version to which to migrate the database.

        """
        print('Migration starting for database %s' % self.get_db_name())
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """Create a fresh database and run all migrations on it.

        @param version: The version to which to migrate the database.

        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """Creates a fresh DB, copies existing DB to it, then synchronizes it.

        @param version: The version to which to migrate the database.

        """
        db_version = self.get_db_version()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def initialize_and_fill_test_db(self):
        """Initializes and fills up a test database.

        Dumps the currently connected DB with mysqldump into a temporary
        file, creates and connects to the test DB, then loads the dump into
        it with mysql. The temporary file is removed even if a step fails.
        Exit statuses of mysqldump/mysql are not checked.
        """
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        # The shell redirections reopen the file by name, so the descriptor
        # returned by mkstemp() is not needed; close it right away.
        os.close(dump_fd)
        quoted_dump_file = self._shell_quote(dump_file)
        try:
            os.system('mysqldump %s >%s' %
                      (self.get_mysql_args(), quoted_dump_file))
            # fill in test DB
            self.initialize_test_db()
            print('Filling in test DB')
            os.system('mysql %s <%s' %
                      (self.get_mysql_args(), quoted_dump_file))
        finally:
            os.remove(dump_file)
412
413
# Help text printed by main() for a missing or unrecognized command.
USAGE = '\n'.join([
    '%s [options] sync|test|simulate|safesync [version]' % sys.argv[0],
    'Options:',
    '    -d --database   Which database to act on',
    "    -f --force      Don't ask for confirmation",
    '    --debug         Print all DB queries',
])
421
422
def main():
    """Main function for the migration script.

    Parses the command line and runs one of the sync/test/simulate/safesync
    commands against the selected database. Prints USAGE when no command or
    an unknown command is given; exits via the option parser's error
    handling when the version argument is not an integer.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database",
                      default="AUTOTEST_WEB")
    parser.add_option("-f", "--force", help="don't ask for confirmation",
                      action="store_true")
    parser.add_option('--debug', help='print all DB queries',
                      action='store_true')
    (options, args) = parser.parse_args()

    # Validate the command line before opening a database connection, so a
    # typo neither connects to the DB nor dies with a raw traceback.
    if not args or args[0] not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return
    command = args[0]

    version = None
    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            # parser.error() prints the message and exits.
            parser.error('invalid version %r: must be an integer' % args[1])

    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug, force=options.force)

    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy of the data, then migrate for real.
        print('Simulating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
463
464
def get_migration_manager(db_name, debug, force):
    """Builds a MigrationManager on top of a newly opened connection.

    @param db_name: The database name passed to DatabaseConnection
        (main() supplies e.g. 'AUTOTEST_WEB' by default).
    @param debug: Whether the connection should print debug messages.
    @param force: Whether to migrate without asking for confirmation.

    @return The created MigrationManager object.

    """
    connection = database_connection.DatabaseConnection(db_name)
    # Configure the connection before connect() is called.
    for attribute, value in (('debug', debug), ('reconnect_enabled', False)):
        setattr(connection, attribute, value)
    connection.connect()
    return MigrationManager(connection, force=force)
480
481
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()
484