1#!/usr/bin/python
2#pylint: disable-msg=C0111
3
4# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
5# Use of this source code is governed by a BSD-style license that can be
6# found in the LICENSE file.
7
8import abc
9import os
10
11import common
12
13from autotest_lib.database import database_connection
14from autotest_lib.frontend import setup_django_environment
15from autotest_lib.frontend.afe import frontend_test_utils
16from autotest_lib.frontend.afe import models
17from autotest_lib.frontend.afe import rdb_model_extensions as rdb_models
18from autotest_lib.scheduler import monitor_db
19from autotest_lib.scheduler import query_managers
20from autotest_lib.scheduler import scheduler_lib
21from autotest_lib.scheduler import scheduler_models
22from autotest_lib.scheduler import rdb_hosts
23from autotest_lib.scheduler import rdb_requests
24from autotest_lib.server.cros import provision
25
26
# Set for verbose table creation output. DBHelper copies this value onto the
# `debug` attribute of the test database it connects.
_DEBUG = False
# Default acl group names, dependency label names and job owner returned by
# get_default_job_params() and get_default_host_params().
DEFAULT_ACLS = ['Everyone', 'my_acl']
DEFAULT_DEPS = ['a', 'b']
DEFAULT_USER = 'system'
32
33
def get_default_job_params():
    """Build the default parameters describing a job.

    @return: A new dict holding the default deps, user and acls, together
        with a priority and a parent_job_id of 0. The deps and acls values
        are the module level default lists themselves, not copies.
    """
    params = dict(deps=DEFAULT_DEPS, user=DEFAULT_USER, acls=DEFAULT_ACLS)
    params['priority'] = 0
    params['parent_job_id'] = 0
    return params
37
38
def get_default_host_params():
    """Build the default parameters describing a host.

    @return: A new dict holding the default deps and acls. The values are
        the module level default lists themselves, not copies.
    """
    return dict(deps=DEFAULT_DEPS, acls=DEFAULT_ACLS)
41
42
class FakeHost(rdb_hosts.RDBHost):
    """An RDBHost built directly from keyword arguments, for unittests."""

    def __init__(self, hostname, host_id, **kwargs):
        """Create a fake host.

        @param hostname: Stored under the 'hostname' key of the host fields.
        @param host_id: Stored under the 'id' key of the host fields.
        @param kwargs: Any other host fields. Together with the two above
            they are passed through
            AbstractHostModel.provide_default_values before being handed to
            the RDBHost constructor.
        """
        kwargs['hostname'] = hostname
        kwargs['id'] = host_id
        host_fields = rdb_models.AbstractHostModel.provide_default_values(
                kwargs)
        super(FakeHost, self).__init__(**host_fields)
51
52
def wire_format_response_map(response_map):
    """Convert every reply in a response map to its wire format.

    @param response_map: A dict-like object keyed by request, whose values
        are iterables of replies implementing wire_format().

    @return: A new dict with the same keys, each mapped to the list of
        wire formatted replies for that request.
    """
    return dict((request, [reply.wire_format() for reply in replies])
                for request, replies in response_map.iteritems())
59
60
class DBHelper(object):
    """Utility class for updating the database.

    Instantiating the class connects a test database. Every other method is
    a classmethod that reads or writes rows through the Django models.
    """

    def __init__(self):
        """Initialize django so it uses an in memory SQLite database."""
        self.database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self.database.connect(db_type='django')
        self.database.debug = _DEBUG


    @classmethod
    def get_labels(cls, **kwargs):
        """Get a label queryset based on the kwargs."""
        return models.Label.objects.filter(**kwargs)


    @classmethod
    def get_acls(cls, **kwargs):
        """Get an aclgroup queryset based on the kwargs."""
        return models.AclGroup.objects.filter(**kwargs)


    @classmethod
    def get_host(cls, **kwargs):
        """Get a host queryset based on the kwargs."""
        return models.Host.objects.filter(**kwargs)


    @classmethod
    def get_hqes(cls, **kwargs):
        """Get a host queue entry queryset based on the kwargs."""
        return models.HostQueueEntry.objects.filter(**kwargs)


    @classmethod
    def get_tasks(cls, **kwargs):
        """Get a special task queryset based on the kwargs."""
        return models.SpecialTask.objects.filter(**kwargs)


    @classmethod
    def get_shard(cls, **kwargs):
        """Get a shard queryset based on the kwargs."""
        return models.Shard.objects.filter(**kwargs)


    @classmethod
    def create_label(cls, name, **kwargs):
        """Get the label matching name and kwargs, creating it if needed.

        @param name: The name of the label.
        @param kwargs: Extra label fields, used both to look the label up
            and to create it.

        @return: The first matching label, or the newly added one.
        """
        existing = cls.get_labels(name=name, **kwargs)
        if existing:
            return existing[0]
        return models.Label.add_object(name=name, **kwargs)


    @classmethod
    def create_user(cls, name):
        """Get the user with the given login, creating it if needed.

        @param name: The login of the user.

        @return: The first matching user, or the newly added one.
        """
        existing = models.User.objects.filter(login=name)
        if existing:
            return existing[0]
        return models.User.add_object(login=name)


    @classmethod
    def create_special_task(cls, job_id=None, host_id=None,
                            task=models.SpecialTask.Task.VERIFY,
                            user='autotest-system'):
        """Create a special task against a host.

        @param job_id: If given (and truthy), the task is attached to the
            first hqe of this job and uses that hqe's host; host_id is
            ignored in that case.
        @param host_id: The id of the host to use when no job_id is given.
        @param task: The type of special task to create.
        @param user: Login of the user requesting the task. The user is
            created if it doesn't already exist.

        @raises ValueError: If no host can be determined, i.e. neither a
            job_id nor a host_id was given, or the job's hqe has no host.

        @return: The new SpecialTask object.
        """
        if job_id:
            queue_entry = cls.get_hqes(job_id=job_id)[0]
            # Read the foreign key column rather than queue_entry.host.id,
            # so an hqe without a host gives None instead of an
            # AttributeError.
            host_id = queue_entry.host_id
        else:
            queue_entry = None
        # Validate before the lookup: objects.get() raises instead of
        # returning a falsy value, so a check on its result can never fire.
        if host_id is None:
            raise ValueError('Require a host to create special tasks.')
        host = models.Host.objects.get(id=host_id)
        owner = cls.create_user(user)
        return models.SpecialTask.objects.create(
                host=host, queue_entry=queue_entry, task=task,
                requested_by_id=owner.id)


    @classmethod
    def create_shard(cls, shard_hostname):
        """Create a shard with the given hostname if one doesn't already exist.

        @param shard_hostname: The hostname of the shard.

        @return: The first matching shard, or the newly created one.
        """
        existing = cls.get_shard(hostname=shard_hostname)
        if existing:
            return existing[0]
        return models.Shard.objects.create(hostname=shard_hostname)


    @classmethod
    def add_labels_to_host(cls, host, label_names=set([])):
        """Add labels to a host, creating any label that doesn't exist.

        @param host: An instance of the host model.
        @param label_names: An iterable of label names.
        """
        label_objects = set(cls.create_label(name) for name in label_names)
        host.labels.add(*label_objects)


    @classmethod
    def create_acl_group(cls, name):
        """Get the acl group with the given name, creating it if needed.

        @param name: The name of the acl group.

        @return: The first matching acl group, or the newly added one.
        """
        existing = cls.get_acls(name=name)
        if existing:
            return existing[0]
        return models.AclGroup.add_object(name=name)


    @classmethod
    def add_deps_to_job(cls, job, dep_names=set([])):
        """Add dependency labels to a job, creating any missing label.

        @param job: An instance of the job model.
        @param dep_names: An iterable of label names.
        """
        label_objects = set(cls.create_label(name) for name in dep_names)
        job.dependency_labels.add(*label_objects)


    @classmethod
    def assign_job_to_shard(cls, job_id, shard_hostname):
        """Assign a job to a shard.

        @param job_id: The id of a job that doesn't have a shard yet.
        @param shard_hostname: The hostname of a shard to assign the job.
            The shard is created if it doesn't already exist.

        @raises ValueError: If there isn't exactly one job with the given
            id and no shard, e.g. because the job already has a shard.
        """
        job_filter = models.Job.objects.filter(id=job_id, shard__isnull=True)
        if len(job_filter) != 1:
            # Both values must be passed to % as one tuple. Without the
            # parentheses the format operation itself raises a TypeError.
            raise ValueError('Failed to assign job %s to shard %s' %
                             (job_id, shard_hostname))
        job_filter.update(shard=cls.create_shard(shard_hostname))


    @classmethod
    def add_host_to_aclgroup(cls, host, aclgroup_names=set([])):
        """Add a host to acl groups, creating any missing group.

        @param host: An instance of the host model.
        @param aclgroup_names: An iterable of acl group names.
        """
        for group_name in aclgroup_names:
            aclgroup = cls.create_acl_group(group_name)
            aclgroup.hosts.add(host)


    @classmethod
    def add_user_to_aclgroups(cls, username, aclgroup_names=set([])):
        """Add a user to acl groups, creating the user and groups if needed.

        @param username: The login of the user.
        @param aclgroup_names: An iterable of acl group names.
        """
        user = cls.create_user(username)
        for group_name in aclgroup_names:
            aclgroup = cls.create_acl_group(group_name)
            aclgroup.users.add(user)


    @classmethod
    def create_host(cls, name, deps=set([]), acls=set([]), status='Ready',
                 locked=0, lock_reason='', leased=0, protection=0, dirty=0):
        """Create a host.

        Also adds the appropriate labels to the host, and adds the host to the
        required acl groups.

        @param name: The hostname.
        @param deps: The labels on the host that match job deps.
        @param acls: The aclgroups this host must be a part of.
        @param status: The status of the host.
        @param locked: 1 if the host is locked.
        @param lock_reason: non-empty string if the host is locked.
        @param leased: 1 if the host is leased.
        @param protection: Any protection level, such as Do Not Verify.
        @param dirty: 1 if the host requires cleanup. Currently accepted
            but not used, see the note in the body.

        @return: The host object for the new host.
        """
        # TODO: Modify this to use the create host request once
        # crbug.com/350995 is fixed.
        # NOTE(review): `dirty` is not forwarded to add_object, so the new
        # host gets whatever default the model defines. Confirm whether
        # that is intended before passing it through.
        host = models.Host.add_object(
                hostname=name, status=status, locked=locked,
                lock_reason=lock_reason, leased=leased,
                protection=protection)
        cls.add_labels_to_host(host, label_names=deps)
        cls.add_host_to_aclgroup(host, aclgroup_names=acls)

        # Though we can return the host object above, this proves that the host
        # actually got saved in the database. For example, the lookup fails
        # if save() wasn't called on the model.Host instance.
        return cls.get_host(hostname=name)[0]


    @classmethod
    def update_hqe(cls, hqe_id, **kwargs):
        """Update the hqe with the given kwargs.

        @param hqe_id: The id of the hqe to update.
        @param kwargs: The hqe fields to set.
        """
        models.HostQueueEntry.objects.filter(id=hqe_id).update(**kwargs)


    @classmethod
    def update_special_task(cls, task_id, **kwargs):
        """Update special tasks with the given kwargs.

        @param task_id: The id of the task to update.
        @param kwargs: The special task fields to set.
        """
        models.SpecialTask.objects.filter(id=task_id).update(**kwargs)


    @classmethod
    def add_host_to_job(cls, host, job_id, activate=0):
        """Add a host to the hqe of a job.

        @param host: An instance of the host model.
        @param job_id: The job to which we need to add the host.
        @param activate: If true, flip the active bit on the hqe.

        @raises ValueError: If the hqe for the job already has a host,
            or if the host argument isn't a Host instance.
        """
        hqe = models.HostQueueEntry.objects.get(job_id=job_id)
        if hqe.host:
            raise ValueError('HQE for job %s already has a host' % job_id)
        hqe.host = host
        hqe.save()
        if activate:
            cls.update_hqe(hqe.id, active=True)


    @classmethod
    def increment_priority(cls, job_id):
        """Increase the priority of a job by one and save it.

        @param job_id: The id of the job to update.
        """
        job = models.Job.objects.get(id=job_id)
        job.priority = job.priority + 1
        job.save()
281
282
class FileDatabaseHelper(object):
    """Helper that points django at a SQLite database stored in a file.

    Note that initializing a file database takes significantly longer than an
    in-memory database and should only be used for functional tests.
    """

    # File used when initialize_database_for_testing isn't given a path.
    DB_FILE = os.path.join(common.autotest_dir, 'host_scheduler_db')

    def initialize_database_for_testing(self, db_file_path=None):
        """Initialize a SQLite database for testing.

        To force monitor_db and the host_scheduler to use the same SQLite file
        database, call this method before initializing the database through
        frontend_test_utils. The host_scheduler is setup to look for the
        host_scheduler_db when invoked with --testing.

        The previous database name and the chosen file path are remembered
        on the instance so teardown_file_database can undo the change.

        @param db_file_path: The name of the file to use to create
            a SQLite database. Since this database is shared across different
            processes using a file is closer to the real world. Falls back
            to DB_FILE when empty.
        """
        db_file_path = db_file_path or self.DB_FILE
        # TODO: Move the translating database elsewhere. Monitor_db circular
        # imports host_scheduler.
        from autotest_lib.frontend import setup_test_environment
        from django.conf import settings
        default_db = settings.DATABASES['default']
        self.old_django_db_name = default_db['NAME']
        default_db['NAME'] = db_file_path
        self.db_file_path = db_file_path
        connection_manager = scheduler_lib.ConnectionManager(autocommit=False)
        test_database = (
                database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        connection_manager.db_connection = test_database


    def teardown_file_database(self):
        """Restore the django database name and delete the database file."""
        # TODO: Move the translating database elsewhere. Monitor_db circular
        # imports host_scheduler.
        from django.conf import settings
        original_name = self.old_django_db_name
        settings.DATABASES['default']['NAME'] = original_name
        # Removing the file is best effort: it may not exist, or no path
        # may have been recorded on this instance.
        try:
            os.remove(self.db_file_path)
        except (OSError, AttributeError):
            pass
329
330
class AbstractBaseRDBTester(frontend_test_utils.FrontendTestMixin):
    """Common fixture and assertion helpers for tests that exercise the rdb.

    setUp connects a test database, stubs it into monitor_db and
    scheduler_models and creates a Dispatcher. The remaining methods create
    jobs and suites and check host assignments.

    The class is a mixin: it relies on the unittest assertion methods, the
    `god` attribute and the `_frontend_common_setup`,
    `_frontend_common_teardown` and `_create_job` helpers being supplied by
    the classes it is combined with (presumably FrontendTestMixin and
    unittest.TestCase).
    """

    # NOTE(review): `__meta__` is not a special attribute, so this is an
    # ordinary class attribute and no metaclass is applied. Presumably
    # `__metaclass__` was intended; confirm before changing it.
    __meta__ = abc.ABCMeta
    _config_section = 'AUTOTEST_WEB'


    @staticmethod
    def get_request(dep_names, acl_names, priority=0, parent_job_id=0):
        """Build an acquire host request from label and acl group names.

        Names without a matching row in the database are silently left out
        of the request.

        @param dep_names: Names of the labels the host must have.
        @param acl_names: Names of the acl groups for the request.
        @param priority: The priority of the request.
        @param parent_job_id: The parent job id of the request.

        @return: The `_request` attribute of the AcquireHostRequest.
        """
        deps = [dep.id for dep in DBHelper.get_labels(name__in=dep_names)]
        acls = [acl.id for acl in DBHelper.get_acls(name__in=acl_names)]
        return rdb_requests.AcquireHostRequest(
                        deps=deps, acls=acls, host_id=None, priority=priority,
                        parent_job_id=parent_job_id)._request


    def _release_unused_hosts(self):
        """Release all hosts unused by an active hqe.

        This is done by running one tick of the host scheduler.
        """
        self.host_scheduler.tick()


    def setUp(self, inline_host_acquisition=True, setup_tables=True):
        """Common setup module for tests that need a jobs/host database.

        @param inline_host_acquisition: If True, the dispatcher tries to acquire
            hosts inline with the rest of the tick.
        @param setup_tables: Passed on to _frontend_common_setup.
        """
        self.db_helper = DBHelper()
        self._database = self.db_helper.database
        # Runs syncdb setting up initial database conditions
        self._frontend_common_setup(setup_tables=setup_tables)
        # Everything has to be stubbed before the Dispatcher is created.
        connection_manager = scheduler_lib.ConnectionManager(autocommit=False)
        self.god.stub_with(connection_manager, 'db_connection', self._database)
        self.god.stub_with(monitor_db, '_db_manager', connection_manager)
        self.god.stub_with(scheduler_models, '_db', self._database)
        self.god.stub_with(monitor_db, '_inline_host_acquisition',
                           inline_host_acquisition)
        self._dispatcher = monitor_db.Dispatcher()
        self.host_scheduler = self._dispatcher._host_scheduler
        self.host_query_manager = query_managers.AFEHostQueryManager()
        self.job_query_manager = self._dispatcher._job_query_manager
        self._release_unused_hosts()


    def tearDown(self):
        """Undo the stubs, disconnect the database and tear down django."""
        self.god.unstub_all()
        self._database.disconnect()
        self._frontend_common_teardown()


    def create_job(self, user='autotest_system',
                   deps=set([]), acls=set([]), hostless_job=False,
                   priority=0, parent_job_id=None, shard_hostname=None):
        """Create a job owned by user, with the deps and acls specified.

        This method is a wrapper around frontend_test_utils.create_job, that
        also takes care of creating the appropriate deps for a job, and the
        appropriate acls for the given user.

        @param user: The login of the owner of the job.
        @param deps: Names of the labels the job depends on. One of them is
            used as the metahost, the rest become dependency labels.
        @param acls: Names of the acl groups the user is added to.
        @param hostless_job: Unused by this method.
        @param priority: The priority of the job.
        @param parent_job_id: The id of the parent job, if any.
        @param shard_hostname: If given, the job is assigned to the shard
            with this hostname.

        @raises ValueError: If no deps are specified for a job, since all jobs
            need at least the metahost.
        @raises AssertionError: If the job doesn't have exactly one hqe.

        @return: An instance of the job model associated with the new job.
        """
        # This is a slight hack around the implementation of
        # scheduler_models.is_hostless_job, even though a metahost is just
        # another label to the rdb.
        if not deps:
            raise ValueError('Need at least one dep for metahost')

        # Materialize once so the metahost and the remaining deps are taken
        # from the same ordering.
        dep_names = list(deps)

        # TODO: This is a hack around the fact that frontend_test_utils still
        # need a metahost, but metahost is treated like any other label.
        metahost = self.db_helper.create_label(dep_names[0])
        job = self._create_job(metahosts=[metahost.id], priority=priority,
                owner=user, parent_job_id=parent_job_id)
        self.assertEqual(len(job.hostqueueentry_set.all()), 1)

        self.db_helper.add_deps_to_job(job, dep_names=dep_names[1:])
        self.db_helper.add_user_to_aclgroups(user, aclgroup_names=acls)
        if shard_hostname:
            self.db_helper.assign_job_to_shard(job.id, shard_hostname)
        # Re-read the job so the returned object reflects the database.
        return models.Job.objects.filter(id=job.id)[0]


    def assert_host_db_status(self, host_id):
        """Assert host state right after acquisition.

        Call this method to check the status of any host leased by the
        rdb before it has been assigned to an hqe. It must be leased and
        ready at this point in time.

        @param host_id: Id of the host to check.

        @raises AssertionError: If the host is either not leased or not Ready.
        """
        host = models.Host.objects.get(id=host_id)
        self.assertTrue(host.leased)
        self.assertEqual(host.status, 'Ready')


    def check_hosts(self, host_iter):
        """Sanity check all hosts in the host_iter.

        @param host_iter: A generator/iterator of RDBClientHostWrappers.
            eg: The generator returned by rdb_lib.acquire_hosts. If a request
            was not satisfied this iterator can contain None.

        @raises AssertionError: If any of the sanity checks fail.
        """
        for host in host_iter:
            if host:
                self.assert_host_db_status(host.id)
                self.assertEqual(host.leased, 1)


    def create_suite(self, user='autotest_system', num=2, priority=0,
                     board='z', build='x', acls=set()):
        """Create num jobs with the same parent_job_id, board, build, priority.

        @param user: The login of the owner of all the jobs.
        @param num: The number of child jobs to create.
        @param priority: The priority of the parent and child jobs.
        @param board: The board label each child job depends on.
        @param build: The build, which is prefixed with the cros-version
            prefix to form the second label each child job depends on.
        @param acls: Names of the acl groups the user is added to.

        @return: A dictionary with the parent job object keyed as 'parent_job'
            and all other jobs keyed at an index from 0 to num - 1.
        """
        jobs = {}
        # Create a hostless parent job without an hqe or deps. Since the
        # hostless job does nothing, we need to hand craft cros-version.
        parent_job = self._create_job(owner=user, priority=priority)
        jobs['parent_job'] = parent_job
        build = '%s:%s' % (provision.CROS_VERSION_PREFIX, build)
        for job_index in range(num):
            jobs[job_index] = self.create_job(user=user, priority=priority,
                                              deps=set([board, build]),
                                              acls=acls,
                                              parent_job_id=parent_job.id)
        return jobs


    def check_host_assignment(self, job_id, host_id):
        """Check if a job<->host assignment is valid.

        Uses the deps of a job and the aclgroups the owner of the job is
        in to see if the given host can be used to run the given job. Also
        checks that the host-job assignment has not been made yet, and that
        the host is leased and Ready.

        Use this method to check host assignments made by the rdb, before
        they're handed off to the scheduler.

        @param job_id: The id of the job to use in the compatibility check.
        @param host_id: The id of the host to check for compatibility.

        @raises AssertionError: If the job and the host are incompatible.
        """
        job = models.Job.objects.get(id=job_id)
        host = models.Host.objects.get(id=host_id)
        hqe = job.hostqueueentry_set.all()[0]

        # Confirm that the host has not been assigned to this hqe, and that
        # at most one incomplete hqe refers to it.
        all_hqes = models.HostQueueEntry.objects.filter(
                host_id=host_id, complete=0)
        self.assertTrue(len(all_hqes) <= 1)
        self.assertTrue(hqe.host_id is None)
        self.assert_host_db_status(host_id)

        # Assert that all deps of the job are satisfied.
        job_deps = set(dep.name for dep in job.dependency_labels.all())
        host_labels = set(label.name for label in host.labels.all())
        self.assertTrue(job_deps.issubset(host_labels))

        # Assert that the owner of the job is in at least one of the
        # groups that owns the host.
        job_owner_aclgroups = set(job_acl.name for job_acl
                                  in job.user().aclgroup_set.all())
        host_aclgroups = set(host_acl.name for host_acl
                             in host.aclgroup_set.all())
        self.assertTrue(job_owner_aclgroups.intersection(host_aclgroups))
507
508
509