1#!/usr/bin/env python2
2# pylint: disable=missing-docstring
3
4import datetime
5import mox
6import unittest
7
8import common
9from autotest_lib.client.common_lib import control_data
10from autotest_lib.client.common_lib import error
11from autotest_lib.client.common_lib import global_config
12from autotest_lib.client.common_lib import priorities
13from autotest_lib.client.common_lib.cros import dev_server
14from autotest_lib.client.common_lib.test_utils import mock
15from autotest_lib.frontend import setup_django_environment
16from autotest_lib.frontend.afe import frontend_test_utils
17from autotest_lib.frontend.afe import model_logic
18from autotest_lib.frontend.afe import models
19from autotest_lib.frontend.afe import rpc_interface
20from autotest_lib.frontend.afe import rpc_utils
21from autotest_lib.server import frontend
22from autotest_lib.server import utils as server_utils
23from autotest_lib.server.cros import provision
24from autotest_lib.server.cros.dynamic_suite import constants
25from autotest_lib.server.cros.dynamic_suite import control_file_getter
26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
27from autotest_lib.server.cros.dynamic_suite import suite_common
28
29
# Control-file type names (e.g. passed as control_type to
# rpc_interface.create_job).
CLIENT, SERVER = (control_data.CONTROL_TYPE_NAMES.CLIENT,
                  control_data.CONTROL_TYPE_NAMES.SERVER)

# Short alias for the HostQueueEntry status enumeration.
_hqe_status = models.HostQueueEntry.Status
34
35
class ShardHeartbeatTest(mox.MoxTestBase, unittest.TestCase):
    """Reusable scenarios for rpc_interface.shard_heartbeat.

    The ``_test*Helper`` methods hold the scenarios; subclasses create the
    shard/host/label fixtures and invoke the helpers from their own
    ``test*`` methods.
    """

    # Priority used for every job created by _createJobForLabel.
    _PRIORITY = priorities.Priority.DEFAULT


    def _do_heartbeat_and_assert_response(self, shard_hostname='shard1',
                                          upload_jobs=(), upload_hqes=(),
                                          known_jobs=(), known_hosts=(),
                                          **kwargs):
        """Send one heartbeat as `shard_hostname` and verify the response.

        @param shard_hostname: Hostname of the shard sending the heartbeat.
        @param upload_jobs: Jobs uploaded by the shard in this heartbeat.
        @param upload_hqes: HQEs uploaded by the shard in this heartbeat.
        @param known_jobs: Job objects the shard reports it already has.
        @param known_hosts: Host objects the shard reports it already has;
                their ids and statuses are sent along with the heartbeat.
        @param kwargs: Expected response contents, forwarded to
                _assert_shard_heartbeat_response (jobs, hosts, hqes,
                incorrect_host_ids).

        @returns: shard_hostname.
        """
        known_job_ids = [job.id for job in known_jobs]
        known_host_ids = [host.id for host in known_hosts]
        known_host_statuses = [host.status for host in known_hosts]

        retval = rpc_interface.shard_heartbeat(
            shard_hostname=shard_hostname,
            jobs=upload_jobs, hqes=upload_hqes,
            known_job_ids=known_job_ids, known_host_ids=known_host_ids,
            known_host_statuses=known_host_statuses)

        self._assert_shard_heartbeat_response(shard_hostname, retval,
                                              **kwargs)

        return shard_hostname


    def _assert_shard_heartbeat_response(self, shard_hostname, retval,
                                         jobs=(), hosts=(), hqes=(),
                                         incorrect_host_ids=()):
        """Check a shard_heartbeat response against the expected records.

        Defaults are immutable tuples so no state can leak between calls
        through a shared mutable default.

        @param shard_hostname: Hostname every returned job's shard must have.
        @param retval: Dict returned by rpc_interface.shard_heartbeat.
        @param jobs: Job objects expected in the response, in order.
        @param hosts: Host objects expected in the response, in order.
        @param hqes: HostQueueEntry objects expected across all returned
                jobs, in order.
        @param incorrect_host_ids: Host ids the response should flag as not
                belonging to the shard.
        """
        retval_hosts, retval_jobs = retval['hosts'], retval['jobs']
        retval_incorrect_hosts = retval['incorrect_host_ids']

        expected_jobs = [(job.id, job.name, shard_hostname) for job in jobs]
        returned_jobs = [(job['id'], job['name'], job['shard']['hostname'])
                         for job in retval_jobs]
        self.assertEqual(returned_jobs, expected_jobs)

        expected_hosts = [(host.id, host.hostname) for host in hosts]
        returned_hosts = [(host['id'], host['hostname'])
                          for host in retval_hosts]
        self.assertEqual(returned_hosts, expected_hosts)

        # HQEs are nested under their jobs in the response; flatten them in
        # job order before comparing.
        returned_hqes = [hqe['id'] for job in retval_jobs
                         for hqe in job['hostqueueentry_set']]
        expected_hqes = [hqe.id for hqe in hqes]
        self.assertEqual(returned_hqes, expected_hqes)

        # Compare against a list (as the former list default did) so that
        # tuple and list expectations behave the same.
        self.assertEqual(retval_incorrect_hosts, list(incorrect_host_ids))


    def _createJobForLabel(self, label):
        """Create a client job whose only metahost/dependency is `label`.

        @param label: models.Label the job is scheduled against.

        @returns: The newly created models.Job.
        """
        job_id = rpc_interface.create_job(name='dummy', priority=self._PRIORITY,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          meta_hosts=[label.name],
                                          dependencies=(label.name,))
        return models.Job.objects.get(id=job_id)


    def _testShardHeartbeatFetchHostlessJobHelper(self, host1):
        """Create a hostless job and ensure it's not assigned to a shard.

        @param host1: Host assigned to the default shard ('shard1').
        """
        # An extra label and a hostless job; neither object is referenced
        # again, only their presence in the database matters.
        models.Label.objects.create(name='bluetooth', platform=False)
        self._create_job(hostless=True)

        # Hostless jobs should be executed by the global scheduler, so the
        # heartbeat returns the shard's host but no jobs.
        self._do_heartbeat_and_assert_response(hosts=[host1])


    def _testShardHeartbeatIncorrectHostsHelper(self, host1):
        """Ensure that hosts that don't belong to shard are determined.

        @param host1: Host assigned to the default shard ('shard1').
        """
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)

        # host2 should not belong to shard1. Ensure that if shard1 thinks host2
        # is a known host, then it is returned as invalid.
        self._do_heartbeat_and_assert_response(known_hosts=[host1, host2],
                                               incorrect_host_ids=[host2.id])


    def _testShardHeartbeatLabelRemovalRaceHelper(self, shard1, host1, label1):
        """Ensure correctness if label removed during heartbeat.

        @param shard1: Shard that owns label1 when the heartbeat starts.
        @param host1: Host already assigned to shard1.
        @param label1: Label removed from shard1 in the middle of the
                heartbeat.
        """
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)
        host2.labels.add(label1)
        self.assertIsNone(host2.shard)

        # In the middle of the assign_to_shard call, remove label1 from shard1.
        self.mox.StubOutWithMock(models.Host, '_assign_to_shard_nothing_helper')
        def remove_label():
            rpc_interface.remove_board_from_shard(shard1.hostname, label1.name)

        models.Host._assign_to_shard_nothing_helper().WithSideEffects(
            remove_label)
        self.mox.ReplayAll()

        # Once the label is gone host1 is reported as not belonging to the
        # shard, and host2 must not have been assigned to it.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1], hosts=[], incorrect_host_ids=[host1.id])
        host2 = models.Host.smart_get(host2.id)
        self.assertIsNone(host2.shard)


    def _testShardRetrieveJobsHelper(self, shard1, host1, label1, shard2,
                                     host2, label2):
        """Create jobs and retrieve them.

        @param shard1: Default shard ('shard1'); owns label1 and host1.
        @param host1: Host on shard1.
        @param label1: Label owned by shard1.
        @param shard2: Second shard; owns label2 and host2.
        @param host2: Host on shard2.
        @param label2: Label owned by shard2.
        """
        # should never be returned by heartbeat
        leased_host = models.Host.objects.create(hostname='leased_host',
                                                 leased=True)

        leased_host.labels.add(label1)

        job1 = self._createJobForLabel(label1)

        job2 = self._createJobForLabel(label2)

        job_completed = self._createJobForLabel(label1)
        # Job is already being run, so don't sync it
        job_completed.hostqueueentry_set.update(complete=True)
        job_completed.hostqueueentry_set.create(complete=False)

        job_active = self._createJobForLabel(label1)
        # Job is already started, so don't sync it
        job_active.hostqueueentry_set.update(active=True)
        job_active.hostqueueentry_set.create(complete=False, active=False)

        # Each shard only receives the job and host for its own label.
        self._do_heartbeat_and_assert_response(
            jobs=[job1], hosts=[host1], hqes=job1.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            shard_hostname=shard2.hostname,
            jobs=[job2], hosts=[host2], hqes=job2.hostqueueentry_set.all())

        host3 = models.Host.objects.create(hostname='test_host3', leased=False)
        host3.labels.add(label1)

        # A newly labelled host is picked up on the next heartbeat.
        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1], hosts=[host3])


    def _testResendJobsAfterFailedHeartbeatHelper(self, shard1, host1, label1):
        """Create jobs, retrieve them, fail on client, fetch them again.

        Also checks that delete_shard detaches the shard from its jobs and
        labels.

        @param shard1: Default shard ('shard1'); owns label1 and host1.
        @param host1: Host on shard1.
        @param label1: Label owned by shard1.
        """
        job1 = self._createJobForLabel(label1)

        self._do_heartbeat_and_assert_response(
            jobs=[job1],
            hqes=job1.hostqueueentry_set.all(), hosts=[host1])

        # The shard still reports no known jobs (the previous heartbeat did
        # not persist on its side), so job1 must be sent again.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job1], hqes=job1.hostqueueentry_set.all(), hosts=[])

        # Now it worked, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1])

        job1 = models.Job.objects.get(pk=job1.id)
        job1.hostqueueentry_set.all().update(complete=True)

        # Job is completed, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1])

        job2 = self._createJobForLabel(label1)

        # job2's creation was later, it should be returned now.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job2], hqes=job2.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1])

        job2 = models.Job.objects.get(pk=job2.pk)
        job2.hostqueueentry_set.update(aborted=True)
        # Setting a job to a complete status will set the shard_id to None in
        # scheduler_models. We have to emulate that here, because we use Django
        # models in tests.
        job2.shard = None
        job2.save()

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1],
            jobs=[job2],
            hqes=job2.hostqueueentry_set.all())

        # NOTE(review): presumably delete_shard looks this test up when
        # handling the shard's hosts -- confirm against rpc_interface.
        models.Test.objects.create(name='platform_BootPerfServer:shard',
                                   test_type=1)
        # read_file is stubbed with no recorded expectations, so mox rejects
        # any call to it made during delete_shard.
        self.mox.StubOutWithMock(server_utils, 'read_file')
        self.mox.ReplayAll()
        rpc_interface.delete_shard(hostname=shard1.hostname)

        self.assertRaises(
            models.Shard.DoesNotExist, models.Shard.objects.get, pk=shard1.id)

        job1 = models.Job.objects.get(pk=job1.id)
        label1 = models.Label.objects.get(pk=label1.id)

        self.assertIsNone(job1.shard)
        self.assertEqual(len(label1.shard_set.all()), 0)


    def _testResendHostsAfterFailedHeartbeatHelper(self, host1):
        """Check that main accepts resending updated records after failure.

        @param host1: Host assigned to the default shard ('shard1').
        """
        # Send the host
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Send it again because previous one didn't persist correctly
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Now it worked, make sure it isn't sent again
        self._do_heartbeat_and_assert_response(known_hosts=[host1])
251
class RpcInterfaceTestWithStaticAttribute(
        mox.MoxTestBase, unittest.TestCase,
        frontend_test_utils.FrontendTestMixin):
    """RPC tests run with RESPECT_STATIC_ATTRIBUTES enabled."""

    def setUp(self):
        super(RpcInterfaceTestWithStaticAttribute, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # Save each module's flag on its own: tearDown restores both, and
        # they need not have started out with the same value.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_ATTRIBUTES
        self.old_models_static_config = getattr(
                models, 'RESPECT_STATIC_ATTRIBUTES',
                self.old_respect_static_config)
        rpc_interface.RESPECT_STATIC_ATTRIBUTES = True
        models.RESPECT_STATIC_ATTRIBUTES = True


    def tearDown(self):
        self.god.unstub_all()
        self._frontend_common_teardown()
        global_config.global_config.reset_config_values()
        rpc_interface.RESPECT_STATIC_ATTRIBUTES = self.old_respect_static_config
        models.RESPECT_STATIC_ATTRIBUTES = self.old_models_static_config


    def _fake_host_with_static_attributes(self):
        """Create 'test_host' with both regular and static attributes.

        test_attribute1 exists in both tables with different values;
        test_attribute2 is regular-only and static_attribute1 static-only.

        @returns: The saved models.Host.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        host1.set_attribute('test_attribute1', 'test_value1')
        host1.set_attribute('test_attribute2', 'test_value2')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        self._set_static_attribute(host1, 'static_attribute1', 'static_value2')
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_attributes()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect the value of static attributes.
        self.assertEqual(host['attributes'],
                         {'test_attribute1': 'static_value1',
                          'test_attribute2': 'test_value2',
                          'static_attribute1': 'static_value2'})

    def test_get_host_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        host2.set_attribute('test_attribute2', 'test_value2')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        hosts = [attr['host'] for attr in attributes]
        values = [attr['value'] for attr in attributes]
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))
        # host1's static value overrides its regular one; host2 has none.
        self.assertEqual(set(values),
                         set(['test_value1', 'static_value1']))


    def test_get_hosts_by_attribute_without_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))


    def test_get_hosts_by_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host2, 'test_attribute1', 'static_value1')
        host3 = models.Host.objects.create(hostname='test_host3')
        self._set_static_attribute(host3, 'test_attribute1', 'test_value1')
        host4 = models.Host.objects.create(hostname='test_host4')
        host4.set_attribute('test_attribute1', 'test_value1')
        host5 = models.Host.objects.create(hostname='test_host5')
        host5.set_attribute('test_attribute1', 'temp_value1')
        self._set_static_attribute(host5, 'test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        # host1: matched, it has the same value for test_attribute1.
        # host2: not matched, it has a new value in
        #        afe_static_host_attributes for test_attribute1.
        # host3: matched, it has a corresponding entry in
        #        afe_host_attributes for test_attribute1.
        # host4: matched, test_attribute1 is not replaced by static
        #        attribute.
        # host5: matched, it has an updated & matched value for
        #        test_attribute1 in afe_static_host_attributes.
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host3',
                              'test_host4', 'test_host5']))
356
357
class RpcInterfaceTestWithStaticLabel(ShardHeartbeatTest,
                                      frontend_test_utils.FrontendTestMixin):
    """RPC and shard-heartbeat tests run with RESPECT_STATIC_LABELS enabled."""

    # Label names that _createShardAndHostWithStaticLabel also mirrors as
    # StaticLabels (and marks as replaced).
    _STATIC_LABELS = ['board:lumpy']

    def setUp(self):
        super(RpcInterfaceTestWithStaticLabel, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # Save each module's flag on its own: tearDown restores both, and
        # they need not have started out with the same value.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_LABELS
        self.old_models_static_config = getattr(
                models, 'RESPECT_STATIC_LABELS',
                self.old_respect_static_config)
        rpc_interface.RESPECT_STATIC_LABELS = True
        models.RESPECT_STATIC_LABELS = True


    def tearDown(self):
        self.god.unstub_all()
        self._frontend_common_teardown()
        global_config.global_config.reset_config_values()
        rpc_interface.RESPECT_STATIC_LABELS = self.old_respect_static_config
        models.RESPECT_STATIC_LABELS = self.old_models_static_config


    def _fake_host_with_static_labels(self):
        """Create 'test_host' carrying both regular and static labels.

        The host gets the regular label 'non_static_label1', a regular
        non-platform label 'static_platform' that is marked as replaced,
        and a StaticLabel 'static_platform' with platform=True.

        @returns: The saved models.Host.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        label1 = models.Label.objects.create(
                name='non_static_label1', platform=False)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.labels.add(label1)
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect all labels in afe_hosts_labels.
        self.assertEqual(host['labels'],
                         ['non_static_label1', 'static_platform'])
        # Respect static labels.
        self.assertEqual(host['platform'], 'static_platform')


    def test_get_hosts_multiple_labels(self):
        self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(
                multiple_labels=['non_static_label1', 'static_platform'])
        host = hosts[0]
        self.assertEqual(host['hostname'], 'test_host')


    def test_delete_static_label(self):
        label1 = models.Label.smart_get('static')

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.delete_label,
                          label1.id)

        self.god.check_playback()


    def test_modify_static_label(self):
        label1 = models.Label.smart_get('static')
        self.assertEqual(label1.invalid, 0)

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.modify_label,
                          label1.id,
                          invalid=1)

        # The rejected modification must not have been persisted.
        self.assertEqual(models.Label.smart_get('static').invalid, 0)
        self.god.check_playback()


    def test_multiple_platforms_add_non_static_to_static(self):
        """Test non-static platform to a host with static platform."""
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        # Second platform label that the RPCs below try to add.
        models.Label.objects.create(name='platform2', platform=True)
        host1 = models.Host.objects.create(hostname='test_host')
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_multiple_platforms_add_static_to_non_static(self):
        """Test static platform to a host with non-static platform."""
        platform1 = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=platform1.id)
        # Static counterpart of platform1; it is what the RPCs below add.
        models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        platform2 = models.Label.objects.create(
                name='platform2', platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(platform2)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts,
                          id='static_platform',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['static_platform'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_label_remove_hosts(self):
        """Test remove a label of hosts."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.static_labels.add(static_label)
        host1.save()

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.label_remove_hosts,
                          id='static', hosts=['test_host'])


    def test_host_remove_labels(self):
        """Test remove labels of a given host."""
        label = models.Label.smart_get('static')
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.labels.add(label1)
        host1.labels.add(label2)
        host1.static_labels.add(static_label)
        host1.save()

        rpc_interface.host_remove_labels(
                'test_host', ['static', 'label1'])
        labels = rpc_interface.get_labels(host__hostname__in=['test_host'])
        # Only non_static label 'label1' is removed.
        self.assertEqual(len(labels), 2)
        self.assertEqual(labels[0].get('name'), 'label2')


    def test_remove_board_from_shard(self):
        """test remove a board (static label) from shard."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        shard = models.Shard.objects.create(hostname='test_shard')
        shard.labels.add(label)

        host = models.Host.objects.create(hostname='test_host',
                                          leased=False,
                                          shard=shard)
        host.static_labels.add(static_label)
        host.save()

        rpc_interface.remove_board_from_shard(shard.hostname, label.name)
        host1 = models.Host.smart_get(host.id)
        shard1 = models.Shard.smart_get(shard.id)
        self.assertIsNone(host1.shard)
        self.assertItemsEqual(shard1.labels.all(), [])


    def test_check_job_dependencies_success(self):
        """Test check_job_dependencies successfully."""
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        rpc_utils.check_job_dependencies([host1], ['static'])


    def test_check_job_dependencies_fail(self):
        """Test check_job_dependencies with raising ValidationError."""
        label = models.Label.smart_get('static')
        # The static label exists but is not attached to the host, so the
        # dependency cannot be satisfied.
        models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        self.assertRaises(model_logic.ValidationError,
                          rpc_utils.check_job_dependencies,
                          [host1],
                          ['static'])

    def test_check_job_metahost_dependencies_success(self):
        """Test check_job_metahost_dependencies successfully."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        rpc_utils.check_job_metahost_dependencies(
                [label1, label], [label2.name])
        rpc_utils.check_job_metahost_dependencies(
                [label1], [label2.name, static_label.name])


    def test_check_job_metahost_dependencies_fail(self):
        """Test check_job_metahost_dependencies with raising errors."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1, label], [label2.name])
        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1], [label2.name, static_label.name])


    def _createShardAndHostWithStaticLabel(self,
                                           shard_hostname='shard1',
                                           host_hostname='test_host1',
                                           label_name='board:lumpy'):
        """Create a shard owning `label_name` plus an unleased host on it.

        When `label_name` is listed in _STATIC_LABELS, the label is also
        marked as replaced and a same-named StaticLabel is attached to the
        host.

        @returns: Tuple (shard, host, label).
        """
        label = models.Label.objects.create(name=label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(hostname=host_hostname, leased=False,
                                          shard=shard)
        host.labels.add(label)
        if label_name in self._STATIC_LABELS:
            models.ReplacedLabel.objects.create(label_id=label.id)
            static_label = models.StaticLabel.objects.create(name=label_name)
            host.static_labels.add(static_label)

        return shard, host, label


    def testShardHeartbeatFetchHostlessJob(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatFetchHostlessJobHelper(host1)


    def testShardHeartbeatIncorrectHosts(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatIncorrectHostsHelper(host1)


    def testShardHeartbeatLabelRemovalRace(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)


    def testShardRetrieveJobs(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        shard2, host2, label2 = self._createShardAndHostWithStaticLabel(
            'shard2', 'test_host2', 'board:grumpy')
        self._testShardRetrieveJobsHelper(shard1, host1, label1,
                                          shard2, host2, label2)


    def testResendJobsAfterFailedHeartbeat(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)


    def testResendHostsAfterFailedHeartbeat(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testResendHostsAfterFailedHeartbeatHelper(host1)
692
693
694class RpcInterfaceTest(unittest.TestCase,
695                       frontend_test_utils.FrontendTestMixin):
    def setUp(self):
        """Set up the frontend test fixtures and a fresh mock_god."""
        self._frontend_common_setup()
        # NOTE(review): self.god is assigned after _frontend_common_setup;
        # whether that helper also sets self.god is not visible here, so keep
        # this order to be sure which instance the tests use.
        self.god = mock.mock_god()
699
700
    def tearDown(self):
        """Undo stubs, tear down frontend fixtures and reset global config."""
        # Remove any stubs installed through self.god before the frontend
        # teardown runs.
        self.god.unstub_all()
        self._frontend_common_teardown()
        # Drop any config values overridden during the test.
        global_config.global_config.reset_config_values()
705
706
707    def test_validation(self):
708        # omit a required field
709        self.assertRaises(model_logic.ValidationError, rpc_interface.add_label,
710                          name=None)
711        # violate uniqueness constraint
712        self.assertRaises(model_logic.ValidationError, rpc_interface.add_host,
713                          hostname='host1')
714
715
716    def test_multiple_platforms(self):
717        platform2 = models.Label.objects.create(name='platform2', platform=True)
718        self.assertRaises(model_logic.ValidationError,
719                          rpc_interface. label_add_hosts, id='platform2',
720                          hosts=['host1', 'host2'])
721        self.assertRaises(model_logic.ValidationError,
722                          rpc_interface.host_add_labels,
723                          id='host1', labels=['platform2'])
724        # make sure the platform didn't get added
725        platforms = rpc_interface.get_labels(
726            host__hostname__in=['host1', 'host2'], platform=True)
727        self.assertEquals(len(platforms), 1)
728        self.assertEquals(platforms[0]['name'], 'myplatform')
729
730
731    def _check_hostnames(self, hosts, expected_hostnames):
732        self.assertEquals(set(host['hostname'] for host in hosts),
733                          set(expected_hostnames))
734
735
736    def test_ping_db(self):
737        self.assertEquals(rpc_interface.ping_db(), [True])
738
739
740    def test_get_hosts_by_attribute(self):
741        host1 = models.Host.objects.create(hostname='test_host1')
742        host1.set_attribute('test_attribute1', 'test_value1')
743        host2 = models.Host.objects.create(hostname='test_host2')
744        host2.set_attribute('test_attribute1', 'test_value1')
745
746        hosts = rpc_interface.get_hosts_by_attribute(
747                'test_attribute1', 'test_value1')
748        self.assertEquals(set(hosts),
749                          set(['test_host1', 'test_host2']))
750
751
752    def test_get_host_attribute(self):
753        host1 = models.Host.objects.create(hostname='test_host1')
754        host1.set_attribute('test_attribute1', 'test_value1')
755        host2 = models.Host.objects.create(hostname='test_host2')
756        host2.set_attribute('test_attribute1', 'test_value1')
757
758        attributes = rpc_interface.get_host_attribute(
759                'test_attribute1',
760                hostname__in=['test_host1', 'test_host2'])
761        hosts = [attr['host'] for attr in attributes]
762        values = [attr['value'] for attr in attributes]
763        self.assertEquals(set(hosts),
764                          set(['test_host1', 'test_host2']))
765        self.assertEquals(set(values), set(['test_value1']))
766
767
768    def test_get_hosts(self):
769        hosts = rpc_interface.get_hosts()
770        self._check_hostnames(hosts, [host.hostname for host in self.hosts])
771
772        hosts = rpc_interface.get_hosts(hostname='host1')
773        self._check_hostnames(hosts, ['host1'])
774        host = hosts[0]
775        self.assertEquals(sorted(host['labels']), ['label1', 'myplatform'])
776        self.assertEquals(host['platform'], 'myplatform')
777        self.assertEquals(host['acls'], ['my_acl'])
778        self.assertEquals(host['attributes'], {})
779
780
781    def test_get_hosts_multiple_labels(self):
782        hosts = rpc_interface.get_hosts(
783                multiple_labels=['myplatform', 'label1'])
784        self._check_hostnames(hosts, ['host1'])
785
786
787    def test_job_keyvals(self):
788        keyval_dict = {'mykey': 'myvalue'}
789        job_id = rpc_interface.create_job(name='test',
790                                          priority=priorities.Priority.DEFAULT,
791                                          control_file='foo',
792                                          control_type=CLIENT,
793                                          hosts=['host1'],
794                                          keyvals=keyval_dict)
795        jobs = rpc_interface.get_jobs(id=job_id)
796        self.assertEquals(len(jobs), 1)
797        self.assertEquals(jobs[0]['keyvals'], keyval_dict)
798
799
800    def test_get_jobs_summary(self):
801        job = self._create_job(hosts=xrange(1, 4))
802        entries = list(job.hostqueueentry_set.all())
803        entries[1].status = _hqe_status.FAILED
804        entries[1].save()
805        entries[2].status = _hqe_status.FAILED
806        entries[2].aborted = True
807        entries[2].save()
808
809        # Mock up tko_rpc_interface.get_status_counts.
810        self.god.stub_function_to_return(rpc_interface.tko_rpc_interface,
811                                         'get_status_counts',
812                                         None)
813
814        job_summaries = rpc_interface.get_jobs_summary(id=job.id)
815        self.assertEquals(len(job_summaries), 1)
816        summary = job_summaries[0]
817        self.assertEquals(summary['status_counts'], {'Queued': 1,
818                                                     'Failed': 2})
819
820
821    def _check_job_ids(self, actual_job_dicts, expected_jobs):
822        self.assertEquals(
823                set(job_dict['id'] for job_dict in actual_job_dicts),
824                set(job.id for job in expected_jobs))
825
826
827    def test_get_jobs_status_filters(self):
828        HqeStatus = models.HostQueueEntry.Status
829        def create_two_host_job():
830            return self._create_job(hosts=[1, 2])
831        def set_hqe_statuses(job, first_status, second_status):
832            entries = job.hostqueueentry_set.all()
833            entries[0].update_object(status=first_status)
834            entries[1].update_object(status=second_status)
835
836        queued = create_two_host_job()
837
838        queued_and_running = create_two_host_job()
839        set_hqe_statuses(queued_and_running, HqeStatus.QUEUED,
840                           HqeStatus.RUNNING)
841
842        running_and_complete = create_two_host_job()
843        set_hqe_statuses(running_and_complete, HqeStatus.RUNNING,
844                           HqeStatus.COMPLETED)
845
846        complete = create_two_host_job()
847        set_hqe_statuses(complete, HqeStatus.COMPLETED, HqeStatus.COMPLETED)
848
849        started_but_inactive = create_two_host_job()
850        set_hqe_statuses(started_but_inactive, HqeStatus.QUEUED,
851                           HqeStatus.COMPLETED)
852
853        parsing = create_two_host_job()
854        set_hqe_statuses(parsing, HqeStatus.PARSING, HqeStatus.PARSING)
855
856        self._check_job_ids(rpc_interface.get_jobs(not_yet_run=True), [queued])
857        self._check_job_ids(rpc_interface.get_jobs(running=True),
858                      [queued_and_running, running_and_complete,
859                       started_but_inactive, parsing])
860        self._check_job_ids(rpc_interface.get_jobs(finished=True), [complete])
861
862
863    def test_get_jobs_type_filters(self):
864        self.assertRaises(AssertionError, rpc_interface.get_jobs,
865                          suite=True, sub=True)
866        self.assertRaises(AssertionError, rpc_interface.get_jobs,
867                          suite=True, standalone=True)
868        self.assertRaises(AssertionError, rpc_interface.get_jobs,
869                          standalone=True, sub=True)
870
871        parent_job = self._create_job(hosts=[1])
872        child_jobs = self._create_job(hosts=[1, 2],
873                                      parent_job_id=parent_job.id)
874        standalone_job = self._create_job(hosts=[1])
875
876        self._check_job_ids(rpc_interface.get_jobs(suite=True), [parent_job])
877        self._check_job_ids(rpc_interface.get_jobs(sub=True), [child_jobs])
878        self._check_job_ids(rpc_interface.get_jobs(standalone=True),
879                            [standalone_job])
880
881
882    def _create_job_helper(self, **kwargs):
883        return rpc_interface.create_job(name='test',
884                                        priority=priorities.Priority.DEFAULT,
885                                        control_file='control file',
886                                        control_type=SERVER, **kwargs)
887
888
889    def test_one_time_hosts(self):
890        job = self._create_job_helper(one_time_hosts=['testhost'])
891        host = models.Host.objects.get(hostname='testhost')
892        self.assertEquals(host.invalid, True)
893        self.assertEquals(host.labels.count(), 0)
894        self.assertEquals(host.aclgroup_set.count(), 0)
895
896
897    def test_create_job_duplicate_hosts(self):
898        self.assertRaises(model_logic.ValidationError, self._create_job_helper,
899                          hosts=[1, 1])
900
901
902    def test_create_unrunnable_metahost_job(self):
903        self.assertRaises(error.NoEligibleHostException,
904                          self._create_job_helper, meta_hosts=['unused'])
905
906
907    def test_create_hostless_job(self):
908        job_id = self._create_job_helper(hostless=True)
909        job = models.Job.objects.get(pk=job_id)
910        queue_entries = job.hostqueueentry_set.all()
911        self.assertEquals(len(queue_entries), 1)
912        self.assertEquals(queue_entries[0].host, None)
913        self.assertEquals(queue_entries[0].meta_host, None)
914
915
916    def _setup_special_tasks(self):
917        host = self.hosts[0]
918
919        job1 = self._create_job(hosts=[1])
920        job2 = self._create_job(hosts=[1])
921
922        entry1 = job1.hostqueueentry_set.all()[0]
923        entry1.update_object(started_on=datetime.datetime(2009, 1, 2),
924                             execution_subdir='host1')
925        entry2 = job2.hostqueueentry_set.all()[0]
926        entry2.update_object(started_on=datetime.datetime(2009, 1, 3),
927                             execution_subdir='host1')
928
929        self.task1 = models.SpecialTask.objects.create(
930                host=host, task=models.SpecialTask.Task.VERIFY,
931                time_started=datetime.datetime(2009, 1, 1), # ran before job 1
932                is_complete=True, requested_by=models.User.current_user())
933        self.task2 = models.SpecialTask.objects.create(
934                host=host, task=models.SpecialTask.Task.VERIFY,
935                queue_entry=entry2, # ran with job 2
936                is_active=True, requested_by=models.User.current_user())
937        self.task3 = models.SpecialTask.objects.create(
938                host=host, task=models.SpecialTask.Task.VERIFY,
939                requested_by=models.User.current_user()) # not yet run
940
941
942    def test_get_special_tasks(self):
943        self._setup_special_tasks()
944        tasks = rpc_interface.get_special_tasks(host__hostname='host1',
945                                                queue_entry__isnull=True)
946        self.assertEquals(len(tasks), 2)
947        self.assertEquals(tasks[0]['task'], models.SpecialTask.Task.VERIFY)
948        self.assertEquals(tasks[0]['is_active'], False)
949        self.assertEquals(tasks[0]['is_complete'], True)
950
951
952    def test_get_latest_special_task(self):
953        # a particular usage of get_special_tasks()
954        self._setup_special_tasks()
955        self.task2.time_started = datetime.datetime(2009, 1, 2)
956        self.task2.save()
957
958        tasks = rpc_interface.get_special_tasks(
959                host__hostname='host1', task=models.SpecialTask.Task.VERIFY,
960                time_started__isnull=False, sort_by=['-time_started'],
961                query_limit=1)
962        self.assertEquals(len(tasks), 1)
963        self.assertEquals(tasks[0]['id'], 2)
964
965
966    def _common_entry_check(self, entry_dict):
967        self.assertEquals(entry_dict['host']['hostname'], 'host1')
968        self.assertEquals(entry_dict['job']['id'], 2)
969
970
971    def test_get_host_queue_entries_and_special_tasks(self):
972        self._setup_special_tasks()
973
974        host = self.hosts[0].id
975        entries_and_tasks = (
976                rpc_interface.get_host_queue_entries_and_special_tasks(host))
977
978        paths = [entry['execution_path'] for entry in entries_and_tasks]
979        self.assertEquals(paths, ['hosts/host1/3-verify',
980                                  '2-autotest_system/host1',
981                                  'hosts/host1/2-verify',
982                                  '1-autotest_system/host1',
983                                  'hosts/host1/1-verify'])
984
985        verify2 = entries_and_tasks[2]
986        self._common_entry_check(verify2)
987        self.assertEquals(verify2['type'], 'Verify')
988        self.assertEquals(verify2['status'], 'Running')
989        self.assertEquals(verify2['execution_path'], 'hosts/host1/2-verify')
990
991        entry2 = entries_and_tasks[1]
992        self._common_entry_check(entry2)
993        self.assertEquals(entry2['type'], 'Job')
994        self.assertEquals(entry2['status'], 'Queued')
995        self.assertEquals(entry2['started_on'], '2009-01-03 00:00:00')
996
997
998    def _create_hqes_and_start_time_index_entries(self):
999        shard = models.Shard.objects.create(hostname='shard')
1000        job = self._create_job(shard=shard, control_file='foo')
1001        HqeStatus = models.HostQueueEntry.Status
1002
1003        models.HostQueueEntry(
1004            id=1, job=job, started_on='2017-01-01',
1005            status=HqeStatus.QUEUED).save()
1006        models.HostQueueEntry(
1007            id=2, job=job, started_on='2017-01-02',
1008            status=HqeStatus.QUEUED).save()
1009        models.HostQueueEntry(
1010            id=3, job=job, started_on='2017-01-03',
1011            status=HqeStatus.QUEUED).save()
1012
1013        models.HostQueueEntryStartTimes(
1014            insert_time='2017-01-03', highest_hqe_id=3).save()
1015        models.HostQueueEntryStartTimes(
1016            insert_time='2017-01-02', highest_hqe_id=2).save()
1017        models.HostQueueEntryStartTimes(
1018            insert_time='2017-01-01', highest_hqe_id=1).save()
1019
1020    def test_get_host_queue_entries_by_insert_time(self):
1021        """Check the insert_time_after and insert_time_before constraints."""
1022        self._create_hqes_and_start_time_index_entries()
1023        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1024            insert_time_after='2017-01-01')
1025        self.assertEquals(len(hqes), 3)
1026
1027        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1028            insert_time_after='2017-01-02')
1029        self.assertEquals(len(hqes), 2)
1030
1031        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1032            insert_time_after='2017-01-03')
1033        self.assertEquals(len(hqes), 1)
1034
1035        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1036            insert_time_before='2017-01-01')
1037        self.assertEquals(len(hqes), 1)
1038
1039        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1040            insert_time_before='2017-01-02')
1041        self.assertEquals(len(hqes), 2)
1042
1043        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1044            insert_time_before='2017-01-03')
1045        self.assertEquals(len(hqes), 3)
1046
1047
1048    def test_get_host_queue_entries_by_insert_time_with_missing_index_row(self):
1049        """Shows that the constraints are approximate.
1050
1051        The query may return rows which are actually outside of the bounds
1052        given, if the index table does not have an entry for the specific time.
1053        """
1054        self._create_hqes_and_start_time_index_entries()
1055        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1056            insert_time_before='2016-12-01')
1057        self.assertEquals(len(hqes), 1)
1058
1059    def test_get_hqe_by_insert_time_with_before_and_after(self):
1060        self._create_hqes_and_start_time_index_entries()
1061        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1062            insert_time_before='2017-01-02',
1063            insert_time_after='2017-01-02')
1064        self.assertEquals(len(hqes), 1)
1065
1066    def test_get_hqe_by_insert_time_and_id_constraint(self):
1067        self._create_hqes_and_start_time_index_entries()
1068        # The time constraint is looser than the id constraint, so the time
1069        # constraint should take precedence.
1070        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1071            insert_time_before='2017-01-02',
1072            id__lte=1)
1073        self.assertEquals(len(hqes), 1)
1074
1075        # Now make the time constraint tighter than the id constraint.
1076        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
1077            insert_time_before='2017-01-01',
1078            id__lte=42)
1079        self.assertEquals(len(hqes), 1)
1080
1081    def test_view_invalid_host(self):
1082        # RPCs used by View Host page should work for invalid hosts
1083        self._create_job_helper(hosts=[1])
1084        host = self.hosts[0]
1085        host.delete()
1086
1087        self.assertEquals(1, rpc_interface.get_num_hosts(hostname='host1',
1088                                                         valid_only=False))
1089        data = rpc_interface.get_hosts(hostname='host1', valid_only=False)
1090        self.assertEquals(1, len(data))
1091
1092        self.assertEquals(1, rpc_interface.get_num_host_queue_entries(
1093                host__hostname='host1'))
1094        data = rpc_interface.get_host_queue_entries(host__hostname='host1')
1095        self.assertEquals(1, len(data))
1096
1097        count = rpc_interface.get_num_host_queue_entries_and_special_tasks(
1098                host=host.id)
1099        self.assertEquals(1, count)
1100        data = rpc_interface.get_host_queue_entries_and_special_tasks(
1101                host=host.id)
1102        self.assertEquals(1, len(data))
1103
1104
1105    def test_reverify_hosts(self):
1106        hostname_list = rpc_interface.reverify_hosts(id__in=[1, 2])
1107        self.assertEquals(hostname_list, ['host1', 'host2'])
1108        tasks = rpc_interface.get_special_tasks()
1109        self.assertEquals(len(tasks), 2)
1110        self.assertEquals(set(task['host']['id'] for task in tasks),
1111                          set([1, 2]))
1112
1113        task = tasks[0]
1114        self.assertEquals(task['task'], models.SpecialTask.Task.VERIFY)
1115        self.assertEquals(task['requested_by'], 'autotest_system')
1116
1117
1118    def test_repair_hosts(self):
1119        hostname_list = rpc_interface.repair_hosts(id__in=[1, 2])
1120        self.assertEquals(hostname_list, ['host1', 'host2'])
1121        tasks = rpc_interface.get_special_tasks()
1122        self.assertEquals(len(tasks), 2)
1123        self.assertEquals(set(task['host']['id'] for task in tasks),
1124                          set([1, 2]))
1125
1126        task = tasks[0]
1127        self.assertEquals(task['task'], models.SpecialTask.Task.REPAIR)
1128        self.assertEquals(task['requested_by'], 'autotest_system')
1129
1130
1131    def _modify_host_helper(self, on_shard=False, host_on_shard=False):
1132        shard_hostname = 'shard1'
1133        if on_shard:
1134            global_config.global_config.override_config_value(
1135                'SHARD', 'shard_hostname', shard_hostname)
1136
1137        host = models.Host.objects.all()[0]
1138        if host_on_shard:
1139            shard = models.Shard.objects.create(hostname=shard_hostname)
1140            host.shard = shard
1141            host.save()
1142
1143        self.assertFalse(host.locked)
1144
1145        self.god.stub_class_method(frontend.AFE, 'run')
1146
1147        if host_on_shard and not on_shard:
1148            mock_afe = self.god.create_mock_class_obj(
1149                    frontend_wrappers.RetryingAFE, 'MockAFE')
1150            self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1151
1152            mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
1153                    server=shard_hostname, user=None)
1154            mock_afe2.run.expect_call('modify_host_local', id=host.id,
1155                    locked=True, lock_reason='_modify_host_helper lock',
1156                    lock_time=datetime.datetime(2015, 12, 15))
1157        elif on_shard:
1158            mock_afe = self.god.create_mock_class_obj(
1159                    frontend_wrappers.RetryingAFE, 'MockAFE')
1160            self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1161
1162            mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
1163                    server=server_utils.get_global_afe_hostname(), user=None)
1164            mock_afe2.run.expect_call('modify_host', id=host.id,
1165                    locked=True, lock_reason='_modify_host_helper lock',
1166                    lock_time=datetime.datetime(2015, 12, 15))
1167
1168        rpc_interface.modify_host(id=host.id, locked=True,
1169                                  lock_reason='_modify_host_helper lock',
1170                                  lock_time=datetime.datetime(2015, 12, 15))
1171
1172        host = models.Host.objects.get(pk=host.id)
1173        if on_shard:
1174            # modify_host on shard does nothing but routing the RPC to main.
1175            self.assertFalse(host.locked)
1176        else:
1177            self.assertTrue(host.locked)
1178        self.god.check_playback()
1179
1180
1181    def test_modify_host_on_main_host_on_main(self):
1182        """Call modify_host to main for host in main."""
1183        self._modify_host_helper()
1184
1185
1186    def test_modify_host_on_main_host_on_shard(self):
1187        """Call modify_host to main for host in shard."""
1188        self._modify_host_helper(host_on_shard=True)
1189
1190
1191    def test_modify_host_on_shard(self):
1192        """Call modify_host to shard for host in shard."""
1193        self._modify_host_helper(on_shard=True, host_on_shard=True)
1194
1195
1196    def test_modify_hosts_on_main_host_on_shard(self):
1197        """Ensure calls to modify_hosts are correctly forwarded to shards."""
1198        host1 = models.Host.objects.all()[0]
1199        host2 = models.Host.objects.all()[1]
1200
1201        shard1 = models.Shard.objects.create(hostname='shard1')
1202        host1.shard = shard1
1203        host1.save()
1204
1205        shard2 = models.Shard.objects.create(hostname='shard2')
1206        host2.shard = shard2
1207        host2.save()
1208
1209        self.assertFalse(host1.locked)
1210        self.assertFalse(host2.locked)
1211
1212        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
1213                                                  'MockAFE')
1214        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1215
1216        # The statuses of one host might differ on main and shard.
1217        # Filters are always applied on the main. So the host on the shard
1218        # will be affected no matter what his status is.
1219        filters_to_use = {'status': 'Ready'}
1220
1221        mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
1222                server='shard2', user=None)
1223        mock_afe2.run.expect_call(
1224            'modify_hosts_local',
1225            host_filter_data={'id__in': [shard1.id, shard2.id]},
1226            update_data={'locked': True,
1227                         'lock_reason': 'Testing forward to shard',
1228                         'lock_time' : datetime.datetime(2015, 12, 15) })
1229
1230        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
1231                server='shard1', user=None)
1232        mock_afe1.run.expect_call(
1233            'modify_hosts_local',
1234            host_filter_data={'id__in': [shard1.id, shard2.id]},
1235            update_data={'locked': True,
1236                         'lock_reason': 'Testing forward to shard',
1237                         'lock_time' : datetime.datetime(2015, 12, 15)})
1238
1239        rpc_interface.modify_hosts(
1240                host_filter_data={'status': 'Ready'},
1241                update_data={'locked': True,
1242                             'lock_reason': 'Testing forward to shard',
1243                             'lock_time' : datetime.datetime(2015, 12, 15) })
1244
1245        host1 = models.Host.objects.get(pk=host1.id)
1246        self.assertTrue(host1.locked)
1247        host2 = models.Host.objects.get(pk=host2.id)
1248        self.assertTrue(host2.locked)
1249        self.god.check_playback()
1250
1251
1252    def test_delete_host(self):
1253        """Ensure an RPC is made on delete a host, if it is on a shard."""
1254        host1 = models.Host.objects.all()[0]
1255        shard1 = models.Shard.objects.create(hostname='shard1')
1256        host1.shard = shard1
1257        host1.save()
1258        host1_id = host1.id
1259
1260        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
1261                                                 'MockAFE')
1262        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1263
1264        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
1265                server='shard1', user=None)
1266        mock_afe1.run.expect_call('delete_host', id=host1.id)
1267
1268        rpc_interface.delete_host(id=host1.id)
1269
1270        self.assertRaises(models.Host.DoesNotExist,
1271                          models.Host.smart_get, host1_id)
1272
1273        self.god.check_playback()
1274
1275
1276    def test_delete_shard(self):
1277        """Ensure the RPC can delete a shard."""
1278        host1 = models.Host.objects.all()[0]
1279        shard1 = models.Shard.objects.create(hostname='shard1')
1280        host1.shard = shard1
1281        host1.save()
1282
1283        rpc_interface.delete_shard(hostname=shard1.hostname)
1284
1285        host1 = models.Host.smart_get(host1.id)
1286        self.assertIsNone(host1.shard)
1287        self.assertRaises(models.Shard.DoesNotExist,
1288                          models.Shard.smart_get, shard1.hostname)
1289
1290
1291    def test_modify_label(self):
1292        label1 = models.Label.objects.all()[0]
1293        self.assertEqual(label1.invalid, 0)
1294
1295        host2 = models.Host.objects.all()[1]
1296        shard1 = models.Shard.objects.create(hostname='shard1')
1297        host2.shard = shard1
1298        host2.labels.add(label1)
1299        host2.save()
1300
1301        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
1302                                                  'MockAFE')
1303        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1304
1305        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
1306                server='shard1', user=None)
1307        mock_afe1.run.expect_call('modify_label', id=label1.id, invalid=1)
1308
1309        rpc_interface.modify_label(label1.id, invalid=1)
1310
1311        self.assertEqual(models.Label.objects.all()[0].invalid, 1)
1312        self.god.check_playback()
1313
1314
1315    def test_delete_label(self):
1316        label1 = models.Label.objects.all()[0]
1317
1318        host2 = models.Host.objects.all()[1]
1319        shard1 = models.Shard.objects.create(hostname='shard1')
1320        host2.shard = shard1
1321        host2.labels.add(label1)
1322        host2.save()
1323
1324        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
1325                                                  'MockAFE')
1326        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
1327
1328        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
1329                server='shard1', user=None)
1330        mock_afe1.run.expect_call('delete_label', id=label1.id)
1331
1332        rpc_interface.delete_label(id=label1.id)
1333
1334        self.assertRaises(models.Label.DoesNotExist,
1335                          models.Label.smart_get, label1.id)
1336        self.god.check_playback()
1337
1338
1339    def test_get_image_for_job_with_keyval_build(self):
1340        keyval_dict = {'build': 'cool-image'}
1341        job_id = rpc_interface.create_job(name='test',
1342                                          priority=priorities.Priority.DEFAULT,
1343                                          control_file='foo',
1344                                          control_type=CLIENT,
1345                                          hosts=['host1'],
1346                                          keyvals=keyval_dict)
1347        job = models.Job.objects.get(id=job_id)
1348        self.assertIsNotNone(job)
1349        image = rpc_interface._get_image_for_job(job, True)
1350        self.assertEquals('cool-image', image)
1351
1352
1353    def test_get_image_for_job_with_keyval_builds(self):
1354        keyval_dict = {'builds': {'cros-version': 'cool-image'}}
1355        job_id = rpc_interface.create_job(name='test',
1356                                          priority=priorities.Priority.DEFAULT,
1357                                          control_file='foo',
1358                                          control_type=CLIENT,
1359                                          hosts=['host1'],
1360                                          keyvals=keyval_dict)
1361        job = models.Job.objects.get(id=job_id)
1362        self.assertIsNotNone(job)
1363        image = rpc_interface._get_image_for_job(job, True)
1364        self.assertEquals('cool-image', image)
1365
1366
1367    def test_get_image_for_job_with_control_build(self):
1368        CONTROL_FILE = """build='cool-image'
1369        """
1370        job_id = rpc_interface.create_job(name='test',
1371                                          priority=priorities.Priority.DEFAULT,
1372                                          control_file='foo',
1373                                          control_type=CLIENT,
1374                                          hosts=['host1'])
1375        job = models.Job.objects.get(id=job_id)
1376        self.assertIsNotNone(job)
1377        job.control_file = CONTROL_FILE
1378        image = rpc_interface._get_image_for_job(job, True)
1379        self.assertEquals('cool-image', image)
1380
1381
1382    def test_get_image_for_job_with_control_builds(self):
1383        CONTROL_FILE = """builds={'cros-version': 'cool-image'}
1384        """
1385        job_id = rpc_interface.create_job(name='test',
1386                                          priority=priorities.Priority.DEFAULT,
1387                                          control_file='foo',
1388                                          control_type=CLIENT,
1389                                          hosts=['host1'])
1390        job = models.Job.objects.get(id=job_id)
1391        self.assertIsNotNone(job)
1392        job.control_file = CONTROL_FILE
1393        image = rpc_interface._get_image_for_job(job, True)
1394        self.assertEquals('cool-image', image)
1395
1396
1397class ExtraRpcInterfaceTest(frontend_test_utils.FrontendTestMixin,
1398                            ShardHeartbeatTest):
1399    """Unit tests for functions originally in site_rpc_interface.py.
1400
1401    @var _NAME: fake suite name.
1402    @var _BOARD: fake board to reimage.
1403    @var _BUILD: fake build with which to reimage.
1404    @var _PRIORITY: fake priority with which to reimage.
1405    """
1406    _NAME = 'name'
1407    _BOARD = 'link'
1408    _BUILD = 'link-release/R36-5812.0.0'
1409    _BUILDS = {provision.CROS_VERSION_PREFIX: _BUILD}
1410    _PRIORITY = priorities.Priority.DEFAULT
1411    _TIMEOUT = 24
1412
1413
1414    def setUp(self):
1415        super(ExtraRpcInterfaceTest, self).setUp()
1416        self._SUITE_NAME = suite_common.canonicalize_suite_name(
1417            self._NAME)
1418        self.dev_server = self.mox.CreateMock(dev_server.ImageServer)
1419        self._frontend_common_setup(fill_data=False)
1420
1421
1422    def tearDown(self):
1423        self._frontend_common_teardown()
1424
1425
    def _setupDevserver(self):
        """Record that resolving self._BUILD yields the mock devserver.

        Replaces dev_server.ImageServer with mox class mocks, then records
        the expectation that dev_server.resolve(self._BUILD) returns
        self.dev_server (the mock ImageServer created in setUp()).
        """
        # NOTE(review): presumably the stub must be installed before the
        # resolve() expectation is recorded; keep this order.
        self.mox.StubOutClassWithMocks(dev_server, 'ImageServer')
        dev_server.resolve(self._BUILD).AndReturn(self.dev_server)
1429
1430
1431    def _mockDevServerGetter(self, get_control_file=True):
1432        self._setupDevserver()
1433        if get_control_file:
1434          self.getter = self.mox.CreateMock(
1435              control_file_getter.DevServerGetter)
1436          self.mox.StubOutWithMock(control_file_getter.DevServerGetter,
1437                                   'create')
1438          control_file_getter.DevServerGetter.create(
1439              mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(self.getter)
1440
1441
1442    def _mockRpcUtils(self, to_return, control_file_substring=''):
1443        """Fake out the autotest rpc_utils module with a mockable class.
1444
1445        @param to_return: the value that rpc_utils.create_job_common() should
1446                          be mocked out to return.
1447        @param control_file_substring: A substring that is expected to appear
1448                                       in the control file output string that
1449                                       is passed to create_job_common.
1450                                       Default: ''
1451        """
1452        download_started_time = constants.DOWNLOAD_STARTED_TIME
1453        payload_finished_time = constants.PAYLOAD_FINISHED_TIME
1454        self.mox.StubOutWithMock(rpc_utils, 'create_job_common')
1455        rpc_utils.create_job_common(mox.And(mox.StrContains(self._NAME),
1456                                    mox.StrContains(self._BUILD)),
1457                            priority=self._PRIORITY,
1458                            timeout_mins=self._TIMEOUT*60,
1459                            max_runtime_mins=self._TIMEOUT*60,
1460                            control_type='Server',
1461                            control_file=mox.And(mox.StrContains(self._BOARD),
1462                                                 mox.StrContains(self._BUILD),
1463                                                 mox.StrContains(
1464                                                     control_file_substring)),
1465                            hostless=True,
1466                            keyvals=mox.And(mox.In(download_started_time),
1467                                            mox.In(payload_finished_time))
1468                            ).AndReturn(to_return)
1469
1470
1471    def testStageBuildFail(self):
1472        """Ensure that a failure to stage the desired build fails the RPC."""
1473        self._setupDevserver()
1474
1475        self.dev_server.hostname = 'mox_url'
1476        self.dev_server.stage_artifacts(
1477                image=self._BUILD,
1478                artifacts=['test_suites', 'control_files']).AndRaise(
1479                dev_server.DevServerException())
1480        self.mox.ReplayAll()
1481        self.assertRaises(error.StageControlFileFailure,
1482                          rpc_interface.create_suite_job,
1483                          name=self._NAME,
1484                          board=self._BOARD,
1485                          builds=self._BUILDS,
1486                          pool=None)
1487
1488
1489    def testGetControlFileFail(self):
1490        """Ensure that a failure to get needed control file fails the RPC."""
1491        self._mockDevServerGetter()
1492
1493        self.dev_server.hostname = 'mox_url'
1494        self.dev_server.stage_artifacts(
1495                image=self._BUILD,
1496                artifacts=['test_suites', 'control_files']).AndReturn(True)
1497
1498        self.getter.get_control_file_contents_by_name(
1499            self._SUITE_NAME).AndReturn(None)
1500        self.mox.ReplayAll()
1501        self.assertRaises(error.ControlFileEmpty,
1502                          rpc_interface.create_suite_job,
1503                          name=self._NAME,
1504                          board=self._BOARD,
1505                          builds=self._BUILDS,
1506                          pool=None)
1507
1508
1509    def testGetControlFileListFail(self):
1510        """Ensure that a failure to get needed control file fails the RPC."""
1511        self._mockDevServerGetter()
1512
1513        self.dev_server.hostname = 'mox_url'
1514        self.dev_server.stage_artifacts(
1515                image=self._BUILD,
1516                artifacts=['test_suites', 'control_files']).AndReturn(True)
1517
1518        self.getter.get_control_file_contents_by_name(
1519            self._SUITE_NAME).AndRaise(error.NoControlFileList())
1520        self.mox.ReplayAll()
1521        self.assertRaises(error.NoControlFileList,
1522                          rpc_interface.create_suite_job,
1523                          name=self._NAME,
1524                          board=self._BOARD,
1525                          builds=self._BUILDS,
1526                          pool=None)
1527
1528
1529    def testCreateSuiteJobFail(self):
1530        """Ensure that failure to schedule the suite job fails the RPC."""
1531        self._mockDevServerGetter()
1532
1533        self.dev_server.hostname = 'mox_url'
1534        self.dev_server.stage_artifacts(
1535                image=self._BUILD,
1536                artifacts=['test_suites', 'control_files']).AndReturn(True)
1537
1538        self.getter.get_control_file_contents_by_name(
1539            self._SUITE_NAME).AndReturn('f')
1540
1541        self.dev_server.url().AndReturn('mox_url')
1542        self._mockRpcUtils(-1)
1543        self.mox.ReplayAll()
1544        self.assertEquals(
1545            rpc_interface.create_suite_job(name=self._NAME,
1546                                           board=self._BOARD,
1547                                           builds=self._BUILDS, pool=None),
1548            -1)
1549
1550
1551    def testCreateSuiteJobSuccess(self):
1552        """Ensures that success results in a successful RPC."""
1553        self._mockDevServerGetter()
1554
1555        self.dev_server.hostname = 'mox_url'
1556        self.dev_server.stage_artifacts(
1557                image=self._BUILD,
1558                artifacts=['test_suites', 'control_files']).AndReturn(True)
1559
1560        self.getter.get_control_file_contents_by_name(
1561            self._SUITE_NAME).AndReturn('f')
1562
1563        self.dev_server.url().AndReturn('mox_url')
1564        job_id = 5
1565        self._mockRpcUtils(job_id)
1566        self.mox.ReplayAll()
1567        self.assertEquals(
1568            rpc_interface.create_suite_job(name=self._NAME,
1569                                           board=self._BOARD,
1570                                           builds=self._BUILDS,
1571                                           pool=None),
1572            job_id)
1573
1574
1575    def testCreateSuiteJobNoHostCheckSuccess(self):
1576        """Ensures that success results in a successful RPC."""
1577        self._mockDevServerGetter()
1578
1579        self.dev_server.hostname = 'mox_url'
1580        self.dev_server.stage_artifacts(
1581                image=self._BUILD,
1582                artifacts=['test_suites', 'control_files']).AndReturn(True)
1583
1584        self.getter.get_control_file_contents_by_name(
1585            self._SUITE_NAME).AndReturn('f')
1586
1587        self.dev_server.url().AndReturn('mox_url')
1588        job_id = 5
1589        self._mockRpcUtils(job_id)
1590        self.mox.ReplayAll()
1591        self.assertEquals(
1592          rpc_interface.create_suite_job(name=self._NAME,
1593                                         board=self._BOARD,
1594                                         builds=self._BUILDS,
1595                                         pool=None, check_hosts=False),
1596          job_id)
1597
1598
1599    def testCreateSuiteJobControlFileSupplied(self):
1600        """Ensure we can supply the control file to create_suite_job."""
1601        self._mockDevServerGetter(get_control_file=False)
1602
1603        self.dev_server.hostname = 'mox_url'
1604        self.dev_server.stage_artifacts(
1605                image=self._BUILD,
1606                artifacts=['test_suites', 'control_files']).AndReturn(True)
1607        self.dev_server.url().AndReturn('mox_url')
1608        job_id = 5
1609        self._mockRpcUtils(job_id)
1610        self.mox.ReplayAll()
1611        self.assertEquals(
1612            rpc_interface.create_suite_job(name='%s/%s' % (self._NAME,
1613                                                           self._BUILD),
1614                                           board=None,
1615                                           builds=self._BUILDS,
1616                                           pool=None,
1617                                           control_file='CONTROL FILE'),
1618            job_id)
1619
1620
1621    def _get_records_for_sending_to_main(self):
1622        return [{'control_file': 'foo',
1623                 'control_type': 1,
1624                 'created_on': datetime.datetime(2014, 8, 21),
1625                 'drone_set': None,
1626                 'email_list': '',
1627                 'max_runtime_hrs': 72,
1628                 'max_runtime_mins': 1440,
1629                 'name': 'dummy',
1630                 'owner': 'autotest_system',
1631                 'parse_failed_repair': True,
1632                 'priority': 40,
1633                 'reboot_after': 0,
1634                 'reboot_before': 1,
1635                 'run_reset': True,
1636                 'run_verify': False,
1637                 'synch_count': 0,
1638                 'test_retry': 0,
1639                 'timeout': 24,
1640                 'timeout_mins': 1440,
1641                 'id': 1
1642                 }], [{
1643                    'aborted': False,
1644                    'active': False,
1645                    'complete': False,
1646                    'deleted': False,
1647                    'execution_subdir': '',
1648                    'finished_on': None,
1649                    'started_on': None,
1650                    'status': 'Queued',
1651                    'id': 1
1652                }]
1653
1654
1655    def _send_records_to_main_helper(
1656        self, jobs, hqes, shard_hostname='host1',
1657        exception_to_throw=error.UnallowedRecordsSentToMain, aborted=False):
1658        job_id = rpc_interface.create_job(
1659                name='dummy',
1660                priority=self._PRIORITY,
1661                control_file='foo',
1662                control_type=SERVER,
1663                hostless=True)
1664        job = models.Job.objects.get(pk=job_id)
1665        shard = models.Shard.objects.create(hostname='host1')
1666        job.shard = shard
1667        job.save()
1668
1669        if aborted:
1670            job.hostqueueentry_set.update(aborted=True)
1671            job.shard = None
1672            job.save()
1673
1674        hqe = job.hostqueueentry_set.all()[0]
1675        if not exception_to_throw:
1676            self._do_heartbeat_and_assert_response(
1677                shard_hostname=shard_hostname,
1678                upload_jobs=jobs, upload_hqes=hqes)
1679        else:
1680            self.assertRaises(
1681                exception_to_throw,
1682                self._do_heartbeat_and_assert_response,
1683                shard_hostname=shard_hostname,
1684                upload_jobs=jobs, upload_hqes=hqes)
1685
1686
1687    def testSendingRecordsToMain(self):
1688        """Send records to the main and ensure they are persisted."""
1689        jobs, hqes = self._get_records_for_sending_to_main()
1690        hqes[0]['status'] = 'Completed'
1691        self._send_records_to_main_helper(
1692            jobs=jobs, hqes=hqes, exception_to_throw=None)
1693
1694        # Check the entry was actually written to db
1695        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
1696                         'Completed')
1697
1698
1699    def testSendingRecordsToMainAbortedOnMain(self):
1700        """Send records to the main and ensure they are persisted."""
1701        jobs, hqes = self._get_records_for_sending_to_main()
1702        hqes[0]['status'] = 'Completed'
1703        self._send_records_to_main_helper(
1704            jobs=jobs, hqes=hqes, exception_to_throw=None, aborted=True)
1705
1706        # Check the entry was actually written to db
1707        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
1708                         'Completed')
1709
1710
1711    def testSendingRecordsToMainJobAssignedToDifferentShard(self):
1712        """Ensure records belonging to different shard are silently rejected."""
1713        shard1 = models.Shard.objects.create(hostname='shard1')
1714        shard2 = models.Shard.objects.create(hostname='shard2')
1715        job1 = self._create_job(shard=shard1, control_file='foo1')
1716        job2 = self._create_job(shard=shard2, control_file='foo2')
1717        job1_id = job1.id
1718        job2_id = job2.id
1719        hqe1 = models.HostQueueEntry.objects.create(job=job1)
1720        hqe2 = models.HostQueueEntry.objects.create(job=job2)
1721        hqe1_id = hqe1.id
1722        hqe2_id = hqe2.id
1723        job1_record = job1.serialize(include_dependencies=False)
1724        job2_record = job2.serialize(include_dependencies=False)
1725        hqe1_record = hqe1.serialize(include_dependencies=False)
1726        hqe2_record = hqe2.serialize(include_dependencies=False)
1727
1728        # Prepare a bogus job record update from the wrong shard. The update
1729        # should not throw an exception. Non-bogus jobs in the same update
1730        # should happily update.
1731        job1_record.update({'control_file': 'bar1'})
1732        job2_record.update({'control_file': 'bar2'})
1733        hqe1_record.update({'status': 'Aborted'})
1734        hqe2_record.update({'status': 'Aborted'})
1735        self._do_heartbeat_and_assert_response(
1736            shard_hostname='shard2', upload_jobs=[job1_record, job2_record],
1737            upload_hqes=[hqe1_record, hqe2_record])
1738
1739        # Job and HQE record for wrong job should not be modified, because the
1740        # rpc came from the wrong shard. Job and HQE record for valid job are
1741        # modified.
1742        self.assertEqual(models.Job.objects.get(id=job1_id).control_file,
1743                         'foo1')
1744        self.assertEqual(models.Job.objects.get(id=job2_id).control_file,
1745                         'bar2')
1746        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe1_id).status,
1747                         '')
1748        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe2_id).status,
1749                         'Aborted')
1750
1751
1752    def testSendingRecordsToMainNotExistingJob(self):
1753        """Ensure update for non existing job gets rejected."""
1754        jobs, hqes = self._get_records_for_sending_to_main()
1755        jobs[0]['id'] = 3
1756
1757        self._send_records_to_main_helper(
1758            jobs=jobs, hqes=hqes)
1759
1760
1761    def _createShardAndHostWithLabel(self, shard_hostname='shard1',
1762                                     host_hostname='host1',
1763                                     label_name='board:lumpy'):
1764        """Create a label, host, shard, and assign host to shard."""
1765        try:
1766            label = models.Label.objects.create(name=label_name)
1767        except:
1768            label = models.Label.smart_get(label_name)
1769
1770        shard = models.Shard.objects.create(hostname=shard_hostname)
1771        shard.labels.add(label)
1772
1773        host = models.Host.objects.create(hostname=host_hostname, leased=False,
1774                                          shard=shard)
1775        host.labels.add(label)
1776
1777        return shard, host, label
1778
1779
1780    def testShardLabelRemovalInvalid(self):
1781        """Ensure you cannot remove the wrong label from shard."""
1782        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1783        stumpy_label = models.Label.objects.create(
1784                name='board:stumpy', platform=True)
1785        with self.assertRaises(error.RPCException):
1786            rpc_interface.remove_board_from_shard(
1787                    shard1.hostname, stumpy_label.name)
1788
1789
1790    def testShardHeartbeatLabelRemoval(self):
1791        """Ensure label removal from shard works."""
1792        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1793
1794        self.assertEqual(host1.shard, shard1)
1795        self.assertItemsEqual(shard1.labels.all(), [lumpy_label])
1796        rpc_interface.remove_board_from_shard(
1797                shard1.hostname, lumpy_label.name)
1798        host1 = models.Host.smart_get(host1.id)
1799        shard1 = models.Shard.smart_get(shard1.id)
1800        self.assertEqual(host1.shard, None)
1801        self.assertItemsEqual(shard1.labels.all(), [])
1802
1803
1804    def testCreateListShard(self):
1805        """Retrieve a list of all shards."""
1806        lumpy_label = models.Label.objects.create(name='board:lumpy',
1807                                                  platform=True)
1808        stumpy_label = models.Label.objects.create(name='board:stumpy',
1809                                                  platform=True)
1810        peppy_label = models.Label.objects.create(name='board:peppy',
1811                                                  platform=True)
1812
1813        shard_id = rpc_interface.add_shard(
1814            hostname='host1', labels='board:lumpy,board:stumpy')
1815        self.assertRaises(error.RPCException,
1816                          rpc_interface.add_shard,
1817                          hostname='host1', labels='board:lumpy,board:stumpy')
1818        self.assertRaises(model_logic.ValidationError,
1819                          rpc_interface.add_shard,
1820                          hostname='host1', labels='board:peppy')
1821        shard = models.Shard.objects.get(pk=shard_id)
1822        self.assertEqual(shard.hostname, 'host1')
1823        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
1824        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
1825
1826        self.assertEqual(rpc_interface.get_shards(),
1827                         [{'labels': ['board:lumpy','board:stumpy'],
1828                           'hostname': 'host1',
1829                           'id': 1}])
1830
1831
1832    def testAddBoardsToShard(self):
1833        """Add boards to a given shard."""
1834        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1835        stumpy_label = models.Label.objects.create(name='board:stumpy',
1836                                                   platform=True)
1837        shard_id = rpc_interface.add_board_to_shard(
1838            hostname='shard1', labels='board:stumpy')
1839        # Test whether raise exception when board label does not exist.
1840        self.assertRaises(models.Label.DoesNotExist,
1841                          rpc_interface.add_board_to_shard,
1842                          hostname='shard1', labels='board:test')
1843        # Test whether raise exception when board already sharded.
1844        self.assertRaises(error.RPCException,
1845                          rpc_interface.add_board_to_shard,
1846                          hostname='shard1', labels='board:lumpy')
1847        shard = models.Shard.objects.get(pk=shard_id)
1848        self.assertEqual(shard.hostname, 'shard1')
1849        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
1850        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
1851
1852        self.assertEqual(rpc_interface.get_shards(),
1853                         [{'labels': ['board:lumpy','board:stumpy'],
1854                           'hostname': 'shard1',
1855                           'id': 1}])
1856
1857
1858    def testShardHeartbeatFetchHostlessJob(self):
1859        shard1, host1, label1 = self._createShardAndHostWithLabel()
1860        self._testShardHeartbeatFetchHostlessJobHelper(host1)
1861
1862
1863    def testShardHeartbeatIncorrectHosts(self):
1864        shard1, host1, label1 = self._createShardAndHostWithLabel()
1865        self._testShardHeartbeatIncorrectHostsHelper(host1)
1866
1867
1868    def testShardHeartbeatLabelRemovalRace(self):
1869        shard1, host1, label1 = self._createShardAndHostWithLabel()
1870        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)
1871
1872
1873    def testShardRetrieveJobs(self):
1874        shard1, host1, label1 = self._createShardAndHostWithLabel()
1875        shard2, host2, label2 = self._createShardAndHostWithLabel(
1876                'shard2', 'host2', 'board:grumpy')
1877        self._testShardRetrieveJobsHelper(shard1, host1, label1,
1878                                          shard2, host2, label2)
1879
1880
1881    def testResendJobsAfterFailedHeartbeat(self):
1882        shard1, host1, label1 = self._createShardAndHostWithLabel()
1883        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)
1884
1885
1886    def testResendHostsAfterFailedHeartbeat(self):
1887        shard1, host1, label1 = self._createShardAndHostWithLabel()
1888        self._testResendHostsAfterFailedHeartbeatHelper(host1)
1889
1890
if __name__ == '__main__':
    # Run every TestCase defined in this module with the stdlib runner.
    unittest.main()
1893