1#!/usr/bin/env python2
2# pylint: disable=missing-docstring
3
4import datetime
5import mox
6import unittest
7
8import common
9from autotest_lib.client.common_lib import control_data
10from autotest_lib.client.common_lib import error
11from autotest_lib.client.common_lib import global_config
12from autotest_lib.client.common_lib import priorities
13from autotest_lib.client.common_lib.cros import dev_server
14from autotest_lib.client.common_lib.test_utils import mock
15from autotest_lib.frontend import setup_django_environment
16from autotest_lib.frontend.afe import frontend_test_utils
17from autotest_lib.frontend.afe import model_logic
18from autotest_lib.frontend.afe import models
19from autotest_lib.frontend.afe import rpc_interface
20from autotest_lib.frontend.afe import rpc_utils
21from autotest_lib.server import frontend
22from autotest_lib.server import utils as server_utils
23from autotest_lib.server.cros import provision
24from autotest_lib.server.cros.dynamic_suite import constants
25from autotest_lib.server.cros.dynamic_suite import control_file_getter
26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
27
# Control-type names passed as ``control_type`` to rpc_interface.create_job().
CLIENT, SERVER = (control_data.CONTROL_TYPE_NAMES.CLIENT,
                  control_data.CONTROL_TYPE_NAMES.SERVER)

# Short alias for models.HostQueueEntry.Status.
_hqe_status = models.HostQueueEntry.Status
32
33
class ShardHeartbeatTest(mox.MoxTestBase, unittest.TestCase):
    """Shared helpers for exercising rpc_interface.shard_heartbeat.

    This class defines no test methods itself; subclasses create the shards,
    hosts and labels they need and then call the ``_test*Helper`` methods.
    Some helpers call ``self._create_job``, which is not defined here --
    presumably supplied by FrontendTestMixin in the concrete subclass
    (TODO confirm).
    """

    _PRIORITY = priorities.Priority.DEFAULT

    def _do_heartbeat_and_assert_response(self, shard_hostname='shard1',
                                          upload_jobs=(), upload_hqes=(),
                                          known_jobs=(), known_hosts=(),
                                          **kwargs):
        """Send one shard heartbeat and check the master's response.

        @param shard_hostname: Hostname of the shard sending the heartbeat.
        @param upload_jobs: Job records the shard uploads to the master.
        @param upload_hqes: HQE records the shard uploads to the master.
        @param known_jobs: Job objects the shard claims to already have;
                only their ids are sent.
        @param known_hosts: Host objects the shard claims to already have;
                their ids and statuses are sent.
        @param kwargs: Expected response contents, forwarded to
                _assert_shard_heartbeat_response (jobs, hosts, hqes,
                incorrect_host_ids).

        @returns: shard_hostname.
        """
        known_job_ids = [job.id for job in known_jobs]
        known_host_ids = [host.id for host in known_hosts]
        known_host_statuses = [host.status for host in known_hosts]

        retval = rpc_interface.shard_heartbeat(
            shard_hostname=shard_hostname,
            jobs=upload_jobs, hqes=upload_hqes,
            known_job_ids=known_job_ids, known_host_ids=known_host_ids,
            known_host_statuses=known_host_statuses)

        self._assert_shard_heartbeat_response(shard_hostname, retval,
                                              **kwargs)

        return shard_hostname


    def _assert_shard_heartbeat_response(self, shard_hostname, retval,
                                         jobs=None, hosts=None, hqes=None,
                                         incorrect_host_ids=None):
        """Assert a heartbeat response contains exactly the expected records.

        Jobs, hosts and HQEs are compared as ordered lists.

        @param shard_hostname: Hostname every returned job must be assigned to.
        @param retval: The dict returned by rpc_interface.shard_heartbeat.
        @param jobs: Expected Job objects, in order. Defaults to none.
        @param hosts: Expected Host objects, in order. Defaults to none.
        @param hqes: Expected HQE objects across all returned jobs, in job
                order. Defaults to none.
        @param incorrect_host_ids: Expected list of host ids reported in
                retval['incorrect_host_ids']. Defaults to [].
        """
        # None sentinels instead of shared mutable [] default arguments.
        jobs = [] if jobs is None else jobs
        hosts = [] if hosts is None else hosts
        hqes = [] if hqes is None else hqes
        if incorrect_host_ids is None:
            incorrect_host_ids = []

        retval_hosts, retval_jobs = retval['hosts'], retval['jobs']
        retval_incorrect_hosts = retval['incorrect_host_ids']

        expected_jobs = [
            (job.id, job.name, shard_hostname) for job in jobs]
        returned_jobs = [(job['id'], job['name'], job['shard']['hostname'])
                         for job in retval_jobs]
        self.assertEqual(returned_jobs, expected_jobs)

        expected_hosts = [(host.id, host.hostname) for host in hosts]
        returned_hosts = [(host['id'], host['hostname'])
                          for host in retval_hosts]
        self.assertEqual(returned_hosts, expected_hosts)

        # HQEs come back nested under their jobs; flatten them in job order.
        expected_hqes = [hqe.id for hqe in hqes]
        returned_hqes = [hqe['id']
                         for job in retval_jobs
                         for hqe in job['hostqueueentry_set']]
        self.assertEqual(returned_hqes, expected_hqes)

        self.assertEqual(retval_incorrect_hosts, incorrect_host_ids)


    def _createJobForLabel(self, label):
        """Create a client job whose meta host and only dependency is label.

        @param label: models.Label used as meta host and dependency.

        @returns: The newly created models.Job.
        """
        job_id = rpc_interface.create_job(name='dummy', priority=self._PRIORITY,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          meta_hosts=[label.name],
                                          dependencies=(label.name,))
        return models.Job.objects.get(id=job_id)


    def _testShardHeartbeatFetchHostlessJobHelper(self, host1):
        """Create a hostless job and ensure it's not assigned to a shard."""
        models.Label.objects.create(name='bluetooth', platform=False)

        self._create_job(hostless=True)

        # Hostless jobs should be executed by the global scheduler, so the
        # heartbeat returns only the shard's host and no jobs.
        self._do_heartbeat_and_assert_response(hosts=[host1])


    def _testShardHeartbeatIncorrectHostsHelper(self, host1):
        """Ensure that hosts that don't belong to shard are determined."""
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)

        # host2 should not belong to shard1. Ensure that if shard1 thinks host2
        # is a known host, then it is returned as invalid.
        self._do_heartbeat_and_assert_response(known_hosts=[host1, host2],
                                               incorrect_host_ids=[host2.id])


    def _testShardHeartbeatLabelRemovalRaceHelper(self, shard1, host1, label1):
        """Ensure correctness if label removed during heartbeat."""
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)
        host2.labels.add(label1)
        self.assertEqual(host2.shard, None)

        # In the middle of the assign_to_shard call, remove label1 from shard1.
        self.mox.StubOutWithMock(models.Host, '_assign_to_shard_nothing_helper')
        def remove_label():
            rpc_interface.remove_board_from_shard(shard1.hostname, label1.name)

        # Record one expected call whose side effect removes the board.
        models.Host._assign_to_shard_nothing_helper().WithSideEffects(
            remove_label)
        self.mox.ReplayAll()

        # With the board removed, no hosts are sent, host1 is reported as
        # incorrect, and host2 must remain unassigned.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1], hosts=[], incorrect_host_ids=[host1.id])
        host2 = models.Host.smart_get(host2.id)
        self.assertEqual(host2.shard, None)


    def _testShardRetrieveJobsHelper(self, shard1, host1, label1, shard2,
                                     host2, label2):
        """Create jobs and retrieve them."""
        # should never be returned by heartbeat
        leased_host = models.Host.objects.create(hostname='leased_host',
                                                 leased=True)

        leased_host.labels.add(label1)

        job1 = self._createJobForLabel(label1)

        job2 = self._createJobForLabel(label2)

        job_completed = self._createJobForLabel(label1)
        # Job is already being run, so don't sync it
        job_completed.hostqueueentry_set.update(complete=True)
        job_completed.hostqueueentry_set.create(complete=False)

        job_active = self._createJobForLabel(label1)
        # Job is already started, so don't sync it
        job_active.hostqueueentry_set.update(active=True)
        job_active.hostqueueentry_set.create(complete=False, active=False)

        self._do_heartbeat_and_assert_response(
            jobs=[job1], hosts=[host1], hqes=job1.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            shard_hostname=shard2.hostname,
            jobs=[job2], hosts=[host2], hqes=job2.hostqueueentry_set.all())

        host3 = models.Host.objects.create(hostname='test_host3', leased=False)
        host3.labels.add(label1)

        # Only the newly labelled host is sent once job1/host1 are known.
        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1], hosts=[host3])


    def _testResendJobsAfterFailedHeartbeatHelper(self, shard1, host1, label1):
        """Create jobs, retrieve them, fail on client, fetch them again."""
        job1 = self._createJobForLabel(label1)

        self._do_heartbeat_and_assert_response(
            jobs=[job1],
            hqes=job1.hostqueueentry_set.all(), hosts=[host1])

        # Make sure it's resubmitted when the shard again reports no known
        # jobs (known_jobs left empty).
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job1], hqes=job1.hostqueueentry_set.all(), hosts=[])

        # Now it worked, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1])

        job1 = models.Job.objects.get(pk=job1.id)
        job1.hostqueueentry_set.all().update(complete=True)

        # Job is completed, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1])

        job2 = self._createJobForLabel(label1)

        # job2's creation was later, it should be returned now.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job2], hqes=job2.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1])

        job2 = models.Job.objects.get(pk=job2.pk)
        job2.hostqueueentry_set.update(aborted=True)
        # Setting a job to a complete status will set the shard_id to None in
        # scheduler_models. We have to emulate that here, because we use Django
        # models in tests.
        job2.shard = None
        job2.save()

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1],
            jobs=[job2],
            hqes=job2.hostqueueentry_set.all())

        models.Test.objects.create(name='platform_BootPerfServer:shard',
                                   test_type=1)
        # read_file is stubbed with no recorded calls, so any call to it
        # during delete_shard is an unexpected call under mox.
        self.mox.StubOutWithMock(server_utils, 'read_file')
        self.mox.ReplayAll()
        rpc_interface.delete_shard(hostname=shard1.hostname)

        self.assertRaises(
            models.Shard.DoesNotExist, models.Shard.objects.get, pk=shard1.id)

        job1 = models.Job.objects.get(pk=job1.id)
        label1 = models.Label.objects.get(pk=label1.id)

        self.assertIsNone(job1.shard)
        self.assertEqual(len(label1.shard_set.all()), 0)


    def _testResendHostsAfterFailedHeartbeatHelper(self, host1):
        """Check that master accepts resending updated records after failure."""
        # Send the host
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Send it again because previous one didn't persist correctly
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Now it worked, make sure it isn't sent again
        self._do_heartbeat_and_assert_response(known_hosts=[host1])
247
248
class RpcInterfaceTestWithStaticAttribute(
        mox.MoxTestBase, unittest.TestCase,
        frontend_test_utils.FrontendTestMixin):
    """RPC tests run with RESPECT_STATIC_ATTRIBUTES enabled."""

    def setUp(self):
        super(RpcInterfaceTestWithStaticAttribute, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # Save each module's flag separately so tearDown restores both
        # exactly, even if the two modules started with different values.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_ATTRIBUTES
        self.old_models_respect_static_config = (
                models.RESPECT_STATIC_ATTRIBUTES)
        rpc_interface.RESPECT_STATIC_ATTRIBUTES = True
        models.RESPECT_STATIC_ATTRIBUTES = True


    def tearDown(self):
        try:
            self.god.unstub_all()
            self._frontend_common_teardown()
            global_config.global_config.reset_config_values()
        finally:
            # Always restore the module flags so a failing teardown cannot
            # leak static-attribute behavior into later tests.
            rpc_interface.RESPECT_STATIC_ATTRIBUTES = (
                    self.old_respect_static_config)
            models.RESPECT_STATIC_ATTRIBUTES = (
                    self.old_models_respect_static_config)


    def _fake_host_with_static_attributes(self):
        """Create 'test_host' with regular and static attributes.

        test_attribute1 is set both normally and statically;
        test_attribute2 is only regular; static_attribute1 is only static.

        @returns: The saved models.Host.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        host1.set_attribute('test_attribute1', 'test_value1')
        host1.set_attribute('test_attribute2', 'test_value2')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        self._set_static_attribute(host1, 'static_attribute1', 'static_value2')
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_attributes()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect the value of static attributes.
        self.assertEqual(host['attributes'],
                         {'test_attribute1': 'static_value1',
                          'test_attribute2': 'test_value2',
                          'static_attribute1': 'static_value2'})

    def test_get_host_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        host2.set_attribute('test_attribute2', 'test_value2')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        hosts = [attr['host'] for attr in attributes]
        values = [attr['value'] for attr in attributes]
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))
        # host1's static value overrides its regular one.
        self.assertEqual(set(values),
                         set(['test_value1', 'static_value1']))


    def test_get_hosts_by_attribute_without_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))


    def test_get_hosts_by_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host2, 'test_attribute1', 'static_value1')
        host3 = models.Host.objects.create(hostname='test_host3')
        self._set_static_attribute(host3, 'test_attribute1', 'test_value1')
        host4 = models.Host.objects.create(hostname='test_host4')
        host4.set_attribute('test_attribute1', 'test_value1')
        host5 = models.Host.objects.create(hostname='test_host5')
        host5.set_attribute('test_attribute1', 'temp_value1')
        self._set_static_attribute(host5, 'test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        # host1: matched, it has the same value for test_attribute1.
        # host2: not matched, it has a new value in
        #        afe_static_host_attributes for test_attribute1.
        # host3: matched, it has a corresponding entry in
        #        afe_host_attributes for test_attribute1.
        # host4: matched, test_attribute1 is not replaced by static
        #        attribute.
        # host5: matched, it has an updated & matched value for
        #        test_attribute1 in afe_static_host_attributes.
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host3',
                              'test_host4', 'test_host5']))
353
354
class RpcInterfaceTestWithStaticLabel(ShardHeartbeatTest,
                                      frontend_test_utils.FrontendTestMixin):
    """RPC and shard-heartbeat tests run with RESPECT_STATIC_LABELS enabled."""

    # Label names that _createShardAndHostWithStaticLabel also mirrors as a
    # StaticLabel (with a ReplacedLabel entry for the original Label).
    _STATIC_LABELS = ['board:lumpy']

    def setUp(self):
        super(RpcInterfaceTestWithStaticLabel, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # Save each module's flag separately so tearDown restores both exactly.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_LABELS
        self.old_models_respect_static_config = models.RESPECT_STATIC_LABELS
        rpc_interface.RESPECT_STATIC_LABELS = True
        models.RESPECT_STATIC_LABELS = True


    def tearDown(self):
        try:
            self.god.unstub_all()
            self._frontend_common_teardown()
            global_config.global_config.reset_config_values()
        finally:
            # Always restore the module flags so a failing teardown cannot
            # leak static-label behavior into later tests.
            rpc_interface.RESPECT_STATIC_LABELS = (
                    self.old_respect_static_config)
            models.RESPECT_STATIC_LABELS = (
                    self.old_models_respect_static_config)


    def _fake_host_with_static_labels(self):
        """Create 'test_host' whose platform comes from a static label.

        The host gets a regular label 'non_static_label1', a non-platform
        Label 'static_platform' marked as replaced, and a platform
        StaticLabel with the same name.

        @returns: The saved models.Host.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        label1 = models.Label.objects.create(
                name='non_static_label1', platform=False)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.labels.add(label1)
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect all labels in afe_hosts_labels.
        self.assertEqual(host['labels'],
                         ['non_static_label1', 'static_platform'])
        # Respect static labels.
        self.assertEqual(host['platform'], 'static_platform')


    def test_get_hosts_multiple_labels(self):
        self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(
                multiple_labels=['non_static_label1', 'static_platform'])
        host = hosts[0]
        self.assertEqual(host['hostname'], 'test_host')


    def test_delete_static_label(self):
        label1 = models.Label.smart_get('static')

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.delete_label,
                          label1.id)

        self.god.check_playback()


    def test_modify_static_label(self):
        label1 = models.Label.smart_get('static')
        self.assertEqual(label1.invalid, 0)

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.modify_label,
                          label1.id,
                          invalid=1)

        # The rejected modification must not have been persisted.
        self.assertEqual(models.Label.smart_get('static').invalid, 0)
        self.god.check_playback()


    def test_multiple_platforms_add_non_static_to_static(self):
        """Test non-static platform to a host with static platform."""
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        models.Label.objects.create(name='platform2', platform=True)
        host1 = models.Host.objects.create(hostname='test_host')
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_multiple_platforms_add_static_to_non_static(self):
        """Test static platform to a host with non-static platform."""
        platform1 = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=platform1.id)
        models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        platform2 = models.Label.objects.create(
                name='platform2', platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(platform2)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts,
                          id='static_platform',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['static_platform'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_label_remove_hosts(self):
        """Test remove a label of hosts."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.static_labels.add(static_label)
        host1.save()

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.label_remove_hosts,
                          id='static', hosts=['test_host'])


    def test_host_remove_labels(self):
        """Test remove labels of a given host."""
        label = models.Label.smart_get('static')
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.labels.add(label1)
        host1.labels.add(label2)
        host1.static_labels.add(static_label)
        host1.save()

        rpc_interface.host_remove_labels(
                'test_host', ['static', 'label1'])
        labels = rpc_interface.get_labels(host__hostname__in=['test_host'])
        # Only non_static label 'label1' is removed.
        self.assertEqual(len(labels), 2)
        self.assertEqual(labels[0].get('name'), 'label2')


    def test_remove_board_from_shard(self):
        """test remove a board (static label) from shard."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        shard = models.Shard.objects.create(hostname='test_shard')
        shard.labels.add(label)

        host = models.Host.objects.create(hostname='test_host',
                                          leased=False,
                                          shard=shard)
        host.static_labels.add(static_label)
        host.save()

        rpc_interface.remove_board_from_shard(shard.hostname, label.name)
        host1 = models.Host.smart_get(host.id)
        shard1 = models.Shard.smart_get(shard.id)
        self.assertEqual(host1.shard, None)
        self.assertItemsEqual(shard1.labels.all(), [])


    def test_check_job_dependencies_success(self):
        """Test check_job_dependencies successfully."""
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        rpc_utils.check_job_dependencies([host1], ['static'])


    def test_check_job_dependencies_fail(self):
        """Test check_job_dependencies with raising ValidationError."""
        label = models.Label.smart_get('static')
        # A StaticLabel named 'static' exists but is not attached to the
        # host; only the non-static Label is.
        models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        self.assertRaises(model_logic.ValidationError,
                          rpc_utils.check_job_dependencies,
                          [host1],
                          ['static'])

    def test_check_job_metahost_dependencies_success(self):
        """Test check_job_metahost_dependencies successfully."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        rpc_utils.check_job_metahost_dependencies(
                [label1, label], [label2.name])
        rpc_utils.check_job_metahost_dependencies(
                [label1], [label2.name, static_label.name])


    def test_check_job_metahost_dependencies_fail(self):
        """Test check_job_metahost_dependencies with raising errors."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        # The host does not carry the 'static' label in either form.
        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1, label], [label2.name])
        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1], [label2.name, static_label.name])


    def _createShardAndHostWithStaticLabel(self,
                                           shard_hostname='shard1',
                                           host_hostname='test_host1',
                                           label_name='board:lumpy'):
        """Create a shard owning one label and an unleased host on it.

        If label_name is listed in _STATIC_LABELS, the Label is also marked
        as replaced and a StaticLabel of the same name is added to the host.

        @returns: A (shard, host, label) tuple.
        """
        label = models.Label.objects.create(name=label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(hostname=host_hostname, leased=False,
                                          shard=shard)
        host.labels.add(label)
        if label_name in self._STATIC_LABELS:
            models.ReplacedLabel.objects.create(label_id=label.id)
            static_label = models.StaticLabel.objects.create(name=label_name)
            host.static_labels.add(static_label)

        return shard, host, label


    def testShardHeartbeatFetchHostlessJob(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatFetchHostlessJobHelper(host1)


    def testShardHeartbeatIncorrectHosts(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatIncorrectHostsHelper(host1)


    def testShardHeartbeatLabelRemovalRace(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)


    def testShardRetrieveJobs(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        shard2, host2, label2 = self._createShardAndHostWithStaticLabel(
            'shard2', 'test_host2', 'board:grumpy')
        self._testShardRetrieveJobsHelper(shard1, host1, label1,
                                          shard2, host2, label2)


    def testResendJobsAfterFailedHeartbeat(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)


    def testResendHostsAfterFailedHeartbeat(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testResendHostsAfterFailedHeartbeatHelper(host1)
689
690
class RpcInterfaceTest(unittest.TestCase,
                       frontend_test_utils.FrontendTestMixin):
    """Tests for rpc_interface running against the frontend test database.

    setUp builds the shared frontend fixtures via _frontend_common_setup()
    (the tests below rely on them providing e.g. host1, host2, label1,
    myplatform and my_acl) and a mock_god used for stubbing. tearDown undoes
    all stubs, tears the fixtures down and resets any global_config values a
    test overrode.
    """

    def setUp(self):
        self._frontend_common_setup()
        self.god = mock.mock_god()


    def tearDown(self):
        self.god.unstub_all()
        self._frontend_common_teardown()
        global_config.global_config.reset_config_values()


    def test_validation(self):
        """add_label/add_host reject invalid input with ValidationError."""
        # omit a required field
        self.assertRaises(model_logic.ValidationError, rpc_interface.add_label,
                          name=None)
        # violate uniqueness constraint
        self.assertRaises(model_logic.ValidationError, rpc_interface.add_host,
                          hostname='host1')


    def test_multiple_platforms(self):
        """Attaching a second platform label to host1/host2 is rejected.

        Both label_add_hosts and host_add_labels must raise, and
        'myplatform' must remain the only platform label on those hosts.
        """
        models.Label.objects.create(name='platform2', platform=True)
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['host1', 'host2'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='host1', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['host1', 'host2'], platform=True)
        self.assertEquals(len(platforms), 1)
        self.assertEquals(platforms[0]['name'], 'myplatform')


    def _check_hostnames(self, hosts, expected_hostnames):
        """Assert the 'hostname' values of `hosts` equal `expected_hostnames`.

        Order and duplicates are ignored (both sides become sets).
        """
        self.assertEquals(set(host['hostname'] for host in hosts),
                          set(expected_hostnames))


    def test_ping_db(self):
        """ping_db returns [True] when the database is reachable."""
        self.assertEquals(rpc_interface.ping_db(), [True])


    def test_get_hosts_by_attribute(self):
        """get_hosts_by_attribute returns hostnames with a matching value."""
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEquals(set(hosts),
                          set(['test_host1', 'test_host2']))


    def test_get_host_attribute(self):
        """get_host_attribute returns one dict per host with host/value."""
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        hosts = [attr['host'] for attr in attributes]
        values = [attr['value'] for attr in attributes]
        self.assertEquals(set(hosts),
                          set(['test_host1', 'test_host2']))
        self.assertEquals(set(values), set(['test_value1']))


    def test_get_hosts(self):
        """get_hosts lists every fixture host and serializes host1 fully."""
        hosts = rpc_interface.get_hosts()
        self._check_hostnames(hosts, [host.hostname for host in self.hosts])

        hosts = rpc_interface.get_hosts(hostname='host1')
        self._check_hostnames(hosts, ['host1'])
        host = hosts[0]
        self.assertEquals(sorted(host['labels']), ['label1', 'myplatform'])
        self.assertEquals(host['platform'], 'myplatform')
        self.assertEquals(host['acls'], ['my_acl'])
        self.assertEquals(host['attributes'], {})


    def test_get_hosts_multiple_labels(self):
        """multiple_labels matches only hosts carrying all given labels."""
        hosts = rpc_interface.get_hosts(
                multiple_labels=['myplatform', 'label1'])
        self._check_hostnames(hosts, ['host1'])


    def test_job_keyvals(self):
        """Keyvals passed to create_job are returned by get_jobs."""
        keyval_dict = {'mykey': 'myvalue'}
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          keyvals=keyval_dict)
        jobs = rpc_interface.get_jobs(id=job_id)
        self.assertEquals(len(jobs), 1)
        self.assertEquals(jobs[0]['keyvals'], keyval_dict)


    def test_test_retry(self):
        """test_retry passed to create_job is returned by get_jobs."""
        job_id = rpc_interface.create_job(name='flake',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          test_retry=10)
        jobs = rpc_interface.get_jobs(id=job_id)
        self.assertEquals(len(jobs), 1)
        self.assertEquals(jobs[0]['test_retry'], 10)


    def test_get_jobs_summary(self):
        """status_counts reflects the HQE statuses of a three-host job."""
        job = self._create_job(hosts=xrange(1, 4))
        entries = list(job.hostqueueentry_set.all())
        entries[1].status = _hqe_status.FAILED
        entries[1].save()
        entries[2].status = _hqe_status.FAILED
        entries[2].aborted = True
        entries[2].save()

        # Mock up tko_rpc_interface.get_status_counts to return None.
        self.god.stub_function_to_return(rpc_interface.tko_rpc_interface,
                                         'get_status_counts',
                                         None)

        job_summaries = rpc_interface.get_jobs_summary(id=job.id)
        self.assertEquals(len(job_summaries), 1)
        summary = job_summaries[0]
        self.assertEquals(summary['status_counts'], {'Queued': 1,
                                                     'Failed': 2})


    def _check_job_ids(self, actual_job_dicts, expected_jobs):
        """Assert the ids in `actual_job_dicts` match `expected_jobs` ids.

        Comparison is order-insensitive (both sides become sets).
        """
        self.assertEquals(
                set(job_dict['id'] for job_dict in actual_job_dicts),
                set(job.id for job in expected_jobs))


    def test_get_jobs_status_filters(self):
        """not_yet_run/running/finished filters classify jobs by HQE status."""
        HqeStatus = models.HostQueueEntry.Status
        def create_two_host_job():
            return self._create_job(hosts=[1, 2])
        def set_hqe_statuses(job, first_status, second_status):
            entries = job.hostqueueentry_set.all()
            entries[0].update_object(status=first_status)
            entries[1].update_object(status=second_status)

        queued = create_two_host_job()

        queued_and_running = create_two_host_job()
        set_hqe_statuses(queued_and_running, HqeStatus.QUEUED,
                         HqeStatus.RUNNING)

        running_and_complete = create_two_host_job()
        set_hqe_statuses(running_and_complete, HqeStatus.RUNNING,
                         HqeStatus.COMPLETED)

        complete = create_two_host_job()
        set_hqe_statuses(complete, HqeStatus.COMPLETED, HqeStatus.COMPLETED)

        started_but_inactive = create_two_host_job()
        set_hqe_statuses(started_but_inactive, HqeStatus.QUEUED,
                         HqeStatus.COMPLETED)

        parsing = create_two_host_job()
        set_hqe_statuses(parsing, HqeStatus.PARSING, HqeStatus.PARSING)

        self._check_job_ids(rpc_interface.get_jobs(not_yet_run=True), [queued])
        self._check_job_ids(rpc_interface.get_jobs(running=True),
                            [queued_and_running, running_and_complete,
                             started_but_inactive, parsing])
        self._check_job_ids(rpc_interface.get_jobs(finished=True), [complete])


    def test_get_jobs_type_filters(self):
        """suite/sub/standalone filters are exclusive and select by parent."""
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          suite=True, sub=True)
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          suite=True, standalone=True)
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          standalone=True, sub=True)

        parent_job = self._create_job(hosts=[1])
        child_job = self._create_job(hosts=[1, 2],
                                     parent_job_id=parent_job.id)
        standalone_job = self._create_job(hosts=[1])

        self._check_job_ids(rpc_interface.get_jobs(suite=True), [parent_job])
        self._check_job_ids(rpc_interface.get_jobs(sub=True), [child_job])
        self._check_job_ids(rpc_interface.get_jobs(standalone=True),
                            [standalone_job])


    def _create_job_helper(self, **kwargs):
        """Create a SERVER job named 'test' via rpc_interface.create_job.

        @param kwargs: Extra keyword arguments forwarded to create_job.
        @returns: Whatever create_job returns (used as a job id by callers).
        """
        return rpc_interface.create_job(name='test',
                                        priority=priorities.Priority.DEFAULT,
                                        control_file='control file',
                                        control_type=SERVER, **kwargs)


    def test_one_time_hosts(self):
        """one_time_hosts creates an invalid host with no labels or ACLs."""
        self._create_job_helper(one_time_hosts=['testhost'])
        host = models.Host.objects.get(hostname='testhost')
        self.assertEquals(host.invalid, True)
        self.assertEquals(host.labels.count(), 0)
        self.assertEquals(host.aclgroup_set.count(), 0)


    def test_create_job_duplicate_hosts(self):
        """Listing the same host twice raises ValidationError."""
        self.assertRaises(model_logic.ValidationError, self._create_job_helper,
                          hosts=[1, 1])


    def test_create_unrunnable_metahost_job(self):
        """A metahost label with no eligible hosts raises."""
        self.assertRaises(error.NoEligibleHostException,
                          self._create_job_helper, meta_hosts=['unused'])


    def test_create_hostless_job(self):
        """A hostless job gets one HQE with neither host nor meta_host."""
        job_id = self._create_job_helper(hostless=True)
        job = models.Job.objects.get(pk=job_id)
        queue_entries = job.hostqueueentry_set.all()
        self.assertEquals(len(queue_entries), 1)
        self.assertEquals(queue_entries[0].host, None)
        self.assertEquals(queue_entries[0].meta_host, None)


    def _setup_special_tasks(self):
        """Create two started jobs and three VERIFY tasks on self.hosts[0].

        Sets self.task1 (complete, started before job 1), self.task2 (active,
        attached to job 2's queue entry) and self.task3 (not yet run).
        """
        host = self.hosts[0]

        job1 = self._create_job(hosts=[1])
        job2 = self._create_job(hosts=[1])

        entry1 = job1.hostqueueentry_set.all()[0]
        entry1.update_object(started_on=datetime.datetime(2009, 1, 2),
                             execution_subdir='host1')
        entry2 = job2.hostqueueentry_set.all()[0]
        entry2.update_object(started_on=datetime.datetime(2009, 1, 3),
                             execution_subdir='host1')

        self.task1 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                time_started=datetime.datetime(2009, 1, 1), # ran before job 1
                is_complete=True, requested_by=models.User.current_user())
        self.task2 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                queue_entry=entry2, # ran with job 2
                is_active=True, requested_by=models.User.current_user())
        self.task3 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                requested_by=models.User.current_user()) # not yet run


    def test_get_special_tasks(self):
        """Filtering on queue_entry__isnull returns the two jobless tasks."""
        self._setup_special_tasks()
        tasks = rpc_interface.get_special_tasks(host__hostname='host1',
                                                queue_entry__isnull=True)
        self.assertEquals(len(tasks), 2)
        self.assertEquals(tasks[0]['task'], models.SpecialTask.Task.VERIFY)
        self.assertEquals(tasks[0]['is_active'], False)
        self.assertEquals(tasks[0]['is_complete'], True)


    def test_get_latest_special_task(self):
        """Sorting by -time_started with query_limit=1 yields the newest."""
        # a particular usage of get_special_tasks()
        self._setup_special_tasks()
        self.task2.time_started = datetime.datetime(2009, 1, 2)
        self.task2.save()

        tasks = rpc_interface.get_special_tasks(
                host__hostname='host1', task=models.SpecialTask.Task.VERIFY,
                time_started__isnull=False, sort_by=['-time_started'],
                query_limit=1)
        self.assertEquals(len(tasks), 1)
        # Compare against the created object rather than a hard-coded id.
        self.assertEquals(tasks[0]['id'], self.task2.id)


    def _common_entry_check(self, entry_dict):
        """Assert `entry_dict` is on host1 and belongs to job id 2."""
        self.assertEquals(entry_dict['host']['hostname'], 'host1')
        self.assertEquals(entry_dict['job']['id'], 2)


    def test_get_host_queue_entries_and_special_tasks(self):
        """HQEs and special tasks for host1 are merged in the listed order."""
        self._setup_special_tasks()

        host = self.hosts[0].id
        entries_and_tasks = (
                rpc_interface.get_host_queue_entries_and_special_tasks(host))

        paths = [entry['execution_path'] for entry in entries_and_tasks]
        self.assertEquals(paths, ['hosts/host1/3-verify',
                                  '2-autotest_system/host1',
                                  'hosts/host1/2-verify',
                                  '1-autotest_system/host1',
                                  'hosts/host1/1-verify'])

        verify2 = entries_and_tasks[2]
        self._common_entry_check(verify2)
        self.assertEquals(verify2['type'], 'Verify')
        self.assertEquals(verify2['status'], 'Running')
        self.assertEquals(verify2['execution_path'], 'hosts/host1/2-verify')

        entry2 = entries_and_tasks[1]
        self._common_entry_check(entry2)
        self.assertEquals(entry2['type'], 'Job')
        self.assertEquals(entry2['status'], 'Queued')
        self.assertEquals(entry2['started_on'], '2009-01-03 00:00:00')


    def _create_hqes_and_start_time_index_entries(self):
        """Create HQEs 1-3 (one per day) plus a matching start-time index.

        Each HostQueueEntryStartTimes row maps insert_time 2017-01-0N to
        highest_hqe_id N.
        """
        shard = models.Shard.objects.create(hostname='shard')
        job = self._create_job(shard=shard, control_file='foo')
        HqeStatus = models.HostQueueEntry.Status

        models.HostQueueEntry(
            id=1, job=job, started_on='2017-01-01',
            status=HqeStatus.QUEUED).save()
        models.HostQueueEntry(
            id=2, job=job, started_on='2017-01-02',
            status=HqeStatus.QUEUED).save()
        models.HostQueueEntry(
            id=3, job=job, started_on='2017-01-03',
            status=HqeStatus.QUEUED).save()

        models.HostQueueEntryStartTimes(
            insert_time='2017-01-03', highest_hqe_id=3).save()
        models.HostQueueEntryStartTimes(
            insert_time='2017-01-02', highest_hqe_id=2).save()
        models.HostQueueEntryStartTimes(
            insert_time='2017-01-01', highest_hqe_id=1).save()


    def test_get_host_queue_entries_by_insert_time(self):
        """Check the insert_time_after and insert_time_before constraints."""
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-01')
        self.assertEquals(len(hqes), 3)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-02')
        self.assertEquals(len(hqes), 2)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-03')
        self.assertEquals(len(hqes), 1)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-01')
        self.assertEquals(len(hqes), 1)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02')
        self.assertEquals(len(hqes), 2)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-03')
        self.assertEquals(len(hqes), 3)


    def test_get_host_queue_entries_by_insert_time_with_missing_index_row(self):
        """Shows that the constraints are approximate.

        The query may return rows which are actually outside of the bounds
        given, if the index table does not have an entry for the specific time.
        """
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2016-12-01')
        self.assertEquals(len(hqes), 1)


    def test_get_hqe_by_insert_time_with_before_and_after(self):
        """Equal before/after bounds select the single HQE from that day."""
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02',
            insert_time_after='2017-01-02')
        self.assertEquals(len(hqes), 1)


    def test_get_hqe_by_insert_time_and_id_constraint(self):
        """Time and id constraints combine; the tighter one decides."""
        self._create_hqes_and_start_time_index_entries()
        # The id constraint is tighter than the time constraint here, so the
        # id constraint determines the result.
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02',
            id__lte=1)
        self.assertEquals(len(hqes), 1)

        # Now make the time constraint tighter than the id constraint.
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-01',
            id__lte=42)
        self.assertEquals(len(hqes), 1)


    def test_view_invalid_host(self):
        """Host-view RPCs still return data for host1 after host.delete()."""
        # RPCs used by View Host page should work for invalid hosts
        self._create_job_helper(hosts=[1])
        host = self.hosts[0]
        host.delete()

        self.assertEquals(1, rpc_interface.get_num_hosts(hostname='host1',
                                                         valid_only=False))
        data = rpc_interface.get_hosts(hostname='host1', valid_only=False)
        self.assertEquals(1, len(data))

        self.assertEquals(1, rpc_interface.get_num_host_queue_entries(
                host__hostname='host1'))
        data = rpc_interface.get_host_queue_entries(host__hostname='host1')
        self.assertEquals(1, len(data))

        count = rpc_interface.get_num_host_queue_entries_and_special_tasks(
                host=host.id)
        self.assertEquals(1, count)
        data = rpc_interface.get_host_queue_entries_and_special_tasks(
                host=host.id)
        self.assertEquals(1, len(data))


    def test_reverify_hosts(self):
        """reverify_hosts schedules one VERIFY task per selected host."""
        hostname_list = rpc_interface.reverify_hosts(id__in=[1, 2])
        self.assertEquals(hostname_list, ['host1', 'host2'])
        tasks = rpc_interface.get_special_tasks()
        self.assertEquals(len(tasks), 2)
        self.assertEquals(set(task['host']['id'] for task in tasks),
                          set([1, 2]))

        task = tasks[0]
        self.assertEquals(task['task'], models.SpecialTask.Task.VERIFY)
        self.assertEquals(task['requested_by'], 'autotest_system')


    def test_repair_hosts(self):
        """repair_hosts schedules one REPAIR task per selected host."""
        hostname_list = rpc_interface.repair_hosts(id__in=[1, 2])
        self.assertEquals(hostname_list, ['host1', 'host2'])
        tasks = rpc_interface.get_special_tasks()
        self.assertEquals(len(tasks), 2)
        self.assertEquals(set(task['host']['id'] for task in tasks),
                          set([1, 2]))

        task = tasks[0]
        self.assertEquals(task['task'], models.SpecialTask.Task.REPAIR)
        self.assertEquals(task['requested_by'], 'autotest_system')


    def _modify_host_helper(self, on_shard=False, host_on_shard=False):
        """Lock the first host via modify_host and check where it lands.

        @param on_shard: If True, global_config is overridden so this AFE
                identifies itself as shard 'shard1'.
        @param host_on_shard: If True, the host is assigned to a newly
                created Shard named 'shard1'.

        Expectations recorded on a mocked RetryingAFE:
          * on_shard: 'modify_host' is forwarded to the global AFE and the
            local host row stays unlocked.
          * host_on_shard only: 'modify_host_local' is forwarded to the
            shard and the local host row is locked.
          * neither: nothing is forwarded and the local host row is locked.
        """
        shard_hostname = 'shard1'
        lock_kwargs = dict(locked=True,
                           lock_reason='_modify_host_helper lock',
                           lock_time=datetime.datetime(2015, 12, 15))
        if on_shard:
            global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', shard_hostname)

        host = models.Host.objects.all()[0]
        if host_on_shard:
            shard = models.Shard.objects.create(hostname=shard_hostname)
            host.shard = shard
            host.save()

        self.assertFalse(host.locked)

        self.god.stub_class_method(frontend.AFE, 'run')

        # Decide which (server, rpc) this helper expects modify_host to
        # forward to: from a shard it goes to the global AFE; from the
        # master it goes to the shard that owns the host.
        forward_to = None
        if on_shard:
            forward_to = (server_utils.get_global_afe_hostname(),
                          'modify_host')
        elif host_on_shard:
            forward_to = (shard_hostname, 'modify_host_local')

        if forward_to is not None:
            server, rpc_name = forward_to
            mock_afe = self.god.create_mock_class_obj(
                    frontend_wrappers.RetryingAFE, 'MockAFE')
            self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

            mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
                    server=server, user=None)
            mock_afe2.run.expect_call(rpc_name, id=host.id, **lock_kwargs)

        rpc_interface.modify_host(id=host.id, **lock_kwargs)

        host = models.Host.objects.get(pk=host.id)
        if on_shard:
            # modify_host on shard does nothing but routing the RPC to master.
            self.assertFalse(host.locked)
        else:
            self.assertTrue(host.locked)
        self.god.check_playback()


    def test_modify_host_on_master_host_on_master(self):
        """Call modify_host to master for host in master."""
        self._modify_host_helper()


    def test_modify_host_on_master_host_on_shard(self):
        """Call modify_host to master for host in shard."""
        self._modify_host_helper(host_on_shard=True)


    def test_modify_host_on_shard(self):
        """Call modify_host to shard for host in shard."""
        self._modify_host_helper(on_shard=True, host_on_shard=True)


    def test_modify_hosts_on_master_host_on_shard(self):
        """Ensure calls to modify_hosts are correctly forwarded to shards."""
        host1 = models.Host.objects.all()[0]
        host2 = models.Host.objects.all()[1]

        shard1 = models.Shard.objects.create(hostname='shard1')
        host1.shard = shard1
        host1.save()

        shard2 = models.Shard.objects.create(hostname='shard2')
        host2.shard = shard2
        host2.save()

        self.assertFalse(host1.locked)
        self.assertFalse(host2.locked)

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        # The statuses of one host might differ on master and shard.
        # Filters are always applied on the master. So the host on the shard
        # will be affected no matter what its status is.
        filters_to_use = {'status': 'Ready'}

        def lock_update_data():
            # A fresh dict per use, so nothing the RPC does to its argument
            # can leak into the recorded expectations.
            return {'locked': True,
                    'lock_reason': 'Testing forward to shard',
                    'lock_time': datetime.datetime(2015, 12, 15)}

        # Shards receive the ids of the affected *hosts* (not shard ids).
        forwarded_filter = {'id__in': [host1.id, host2.id]}

        mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
                server='shard2', user=None)
        mock_afe2.run.expect_call(
            'modify_hosts_local',
            host_filter_data=forwarded_filter,
            update_data=lock_update_data())

        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
                server='shard1', user=None)
        mock_afe1.run.expect_call(
            'modify_hosts_local',
            host_filter_data=forwarded_filter,
            update_data=lock_update_data())

        rpc_interface.modify_hosts(
                host_filter_data=filters_to_use,
                update_data=lock_update_data())

        host1 = models.Host.objects.get(pk=host1.id)
        self.assertTrue(host1.locked)
        host2 = models.Host.objects.get(pk=host2.id)
        self.assertTrue(host2.locked)
        self.god.check_playback()


    def test_delete_host(self):
        """Ensure an RPC is made on delete a host, if it is on a shard."""
        host1 = models.Host.objects.all()[0]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host1.shard = shard1
        host1.save()
        host1_id = host1.id

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                 'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
                server='shard1', user=None)
        mock_afe1.run.expect_call('delete_host', id=host1.id)

        rpc_interface.delete_host(id=host1.id)

        self.assertRaises(models.Host.DoesNotExist,
                          models.Host.smart_get, host1_id)

        self.god.check_playback()


    def test_modify_label(self):
        """modify_label updates locally and forwards to the label's shard."""
        label1 = models.Label.objects.all()[0]
        self.assertEqual(label1.invalid, 0)

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
                server='shard1', user=None)
        mock_afe1.run.expect_call('modify_label', id=label1.id, invalid=1)

        rpc_interface.modify_label(label1.id, invalid=1)

        self.assertEqual(models.Label.objects.all()[0].invalid, 1)
        self.god.check_playback()


    def test_delete_label(self):
        """delete_label removes locally and forwards to the label's shard."""
        label1 = models.Label.objects.all()[0]

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
                                                  'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)

        mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
                server='shard1', user=None)
        mock_afe1.run.expect_call('delete_label', id=label1.id)

        rpc_interface.delete_label(id=label1.id)

        self.assertRaises(models.Label.DoesNotExist,
                          models.Label.smart_get, label1.id)
        self.god.check_playback()


    def _get_image_for_new_job(self, control_file_override=None,
                               **create_job_kwargs):
        """Create a CLIENT job on host1 and return its image.

        @param control_file_override: If not None, replaces the in-memory
                job's control_file (the job is created with 'foo') before the
                image lookup.
        @param create_job_kwargs: Extra keyword arguments for create_job,
                e.g. keyvals.
        @returns: rpc_interface._get_image_for_job(job, True).
        """
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          **create_job_kwargs)
        job = models.Job.objects.get(id=job_id)
        self.assertIsNotNone(job)
        if control_file_override is not None:
            job.control_file = control_file_override
        return rpc_interface._get_image_for_job(job, True)


    def test_get_image_for_job_with_keyval_build(self):
        """The image is read from a 'build' job keyval."""
        image = self._get_image_for_new_job(
                keyvals={'build': 'cool-image'})
        self.assertEquals('cool-image', image)


    def test_get_image_for_job_with_keyval_builds(self):
        """The image is read from the cros-version entry of 'builds'."""
        image = self._get_image_for_new_job(
                keyvals={'builds': {'cros-version': 'cool-image'}})
        self.assertEquals('cool-image', image)


    def test_get_image_for_job_with_control_build(self):
        """The image is read from a build= assignment in the control file."""
        CONTROL_FILE = """build='cool-image'
        """
        image = self._get_image_for_new_job(
                control_file_override=CONTROL_FILE)
        self.assertEquals('cool-image', image)


    def test_get_image_for_job_with_control_builds(self):
        """The image is read from a builds= dict in the control file."""
        CONTROL_FILE = """builds={'cros-version': 'cool-image'}
        """
        image = self._get_image_for_new_job(
                control_file_override=CONTROL_FILE)
        self.assertEquals('cool-image', image)
1389
1390
class ExtraRpcInterfaceTest(frontend_test_utils.FrontendTestMixin,
                            ShardHeartbeatTest):
    """Unit tests for functions originally in site_rpc_interface.py.

    Covers create_suite_job (build staging, control file retrieval and job
    creation), shard heartbeats that upload job/HQE records to the master,
    and the shard/board management RPCs.

    @var _NAME: fake suite name.
    @var _BOARD: fake board to reimage.
    @var _BUILD: fake build with which to reimage.
    @var _BUILDS: fake builds dict mapping the CrOS version prefix to _BUILD.
    @var _PRIORITY: fake priority with which to reimage.
    @var _TIMEOUT: fake suite timeout in hours; multiplied by 60 to obtain the
                   expected timeout_mins / max_runtime_mins.
    """
    _NAME = 'name'
    _BOARD = 'link'
    _BUILD = 'link-release/R36-5812.0.0'
    _BUILDS = {provision.CROS_VERSION_PREFIX: _BUILD}
    _PRIORITY = priorities.Priority.DEFAULT
    _TIMEOUT = 24


    def setUp(self):
        super(ExtraRpcInterfaceTest, self).setUp()
        # Suite name the tests expect in get_control_file_contents_by_name.
        self._SUITE_NAME = rpc_interface.canonicalize_suite_name(
            self._NAME)
        # Mock devserver returned by dev_server.resolve (see _setupDevserver).
        self.dev_server = self.mox.CreateMock(dev_server.ImageServer)
        self._frontend_common_setup(fill_data=False)


    def tearDown(self):
        self._frontend_common_teardown()
        # Mirror setUp: let the base classes run their own teardown as well.
        super(ExtraRpcInterfaceTest, self).tearDown()


    def _setupDevserver(self):
        """Stub dev_server.ImageServer and expect the build to be resolved.

        Records an expectation that dev_server.resolve(self._BUILD) is called
        and returns self.dev_server.
        """
        self.mox.StubOutClassWithMocks(dev_server, 'ImageServer')
        dev_server.resolve(self._BUILD).AndReturn(self.dev_server)


    def _mockDevServerGetter(self, get_control_file=True):
        """Set up devserver mocks and, optionally, a mock control file getter.

        @param get_control_file: If True, stub out DevServerGetter.create so
                                 that it returns self.getter, a mock
                                 DevServerGetter on which tests can record
                                 expectations.
        """
        self._setupDevserver()
        if get_control_file:
            self.getter = self.mox.CreateMock(
                control_file_getter.DevServerGetter)
            self.mox.StubOutWithMock(control_file_getter.DevServerGetter,
                                     'create')
            control_file_getter.DevServerGetter.create(
                mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(self.getter)


    def _mockRpcUtils(self, to_return, control_file_substring=''):
        """Stub rpc_utils.create_job_common and record one expected call.

        The expected call carries a name containing _NAME and _BUILD, a
        Server control file containing _BOARD, _BUILD and
        control_file_substring, and keyvals containing the download-started
        and payload-finished time keys.

        @param to_return: the value that rpc_utils.create_job_common() should
                          be mocked out to return.
        @param control_file_substring: A substring that is expected to appear
                                       in the control file output string that
                                       is passed to create_job_common.
                                       Default: ''
        """
        download_started_time = constants.DOWNLOAD_STARTED_TIME
        payload_finished_time = constants.PAYLOAD_FINISHED_TIME
        self.mox.StubOutWithMock(rpc_utils, 'create_job_common')
        rpc_utils.create_job_common(
            mox.And(mox.StrContains(self._NAME),
                    mox.StrContains(self._BUILD)),
            priority=self._PRIORITY,
            timeout_mins=self._TIMEOUT * 60,
            max_runtime_mins=self._TIMEOUT * 60,
            control_type='Server',
            control_file=mox.And(mox.StrContains(self._BOARD),
                                 mox.StrContains(self._BUILD),
                                 mox.StrContains(control_file_substring)),
            hostless=True,
            keyvals=mox.And(mox.In(download_started_time),
                            mox.In(payload_finished_time))
        ).AndReturn(to_return)


    def testStageBuildFail(self):
        """Ensure that a failure to stage the desired build fails the RPC."""
        self._setupDevserver()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndRaise(
                dev_server.DevServerException())
        self.mox.ReplayAll()
        self.assertRaises(error.StageControlFileFailure,
                          rpc_interface.create_suite_job,
                          name=self._NAME,
                          board=self._BOARD,
                          builds=self._BUILDS,
                          pool=None)


    def testGetControlFileFail(self):
        """Ensure that an empty suite control file fails the RPC."""
        self._mockDevServerGetter()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)

        # The getter returns None, i.e. no control file contents.
        self.getter.get_control_file_contents_by_name(
            self._SUITE_NAME).AndReturn(None)
        self.mox.ReplayAll()
        self.assertRaises(error.ControlFileEmpty,
                          rpc_interface.create_suite_job,
                          name=self._NAME,
                          board=self._BOARD,
                          builds=self._BUILDS,
                          pool=None)


    def testGetControlFileListFail(self):
        """Ensure that a failure to get the control file list fails the RPC."""
        self._mockDevServerGetter()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)

        self.getter.get_control_file_contents_by_name(
            self._SUITE_NAME).AndRaise(error.NoControlFileList())
        self.mox.ReplayAll()
        self.assertRaises(error.NoControlFileList,
                          rpc_interface.create_suite_job,
                          name=self._NAME,
                          board=self._BOARD,
                          builds=self._BUILDS,
                          pool=None)


    def testCreateSuiteJobFail(self):
        """Ensure that failure to schedule the suite job fails the RPC."""
        self._mockDevServerGetter()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)

        self.getter.get_control_file_contents_by_name(
            self._SUITE_NAME).AndReturn('f')

        self.dev_server.url().AndReturn('mox_url')
        # create_job_common reports the scheduling failure as -1.
        self._mockRpcUtils(-1)
        self.mox.ReplayAll()
        self.assertEqual(
            rpc_interface.create_suite_job(name=self._NAME,
                                           board=self._BOARD,
                                           builds=self._BUILDS, pool=None),
            -1)


    def testCreateSuiteJobSuccess(self):
        """Ensures that success results in a successful RPC."""
        self._mockDevServerGetter()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)

        self.getter.get_control_file_contents_by_name(
            self._SUITE_NAME).AndReturn('f')

        self.dev_server.url().AndReturn('mox_url')
        job_id = 5
        self._mockRpcUtils(job_id)
        self.mox.ReplayAll()
        self.assertEqual(
            rpc_interface.create_suite_job(name=self._NAME,
                                           board=self._BOARD,
                                           builds=self._BUILDS,
                                           pool=None),
            job_id)


    def testCreateSuiteJobNoHostCheckSuccess(self):
        """Ensures that success with check_hosts=False gives a successful RPC."""
        self._mockDevServerGetter()

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)

        self.getter.get_control_file_contents_by_name(
            self._SUITE_NAME).AndReturn('f')

        self.dev_server.url().AndReturn('mox_url')
        job_id = 5
        self._mockRpcUtils(job_id)
        self.mox.ReplayAll()
        self.assertEqual(
            rpc_interface.create_suite_job(name=self._NAME,
                                           board=self._BOARD,
                                           builds=self._BUILDS,
                                           pool=None, check_hosts=False),
            job_id)


    def testCreateSuiteJobControlFileSupplied(self):
        """Ensure we can supply the control file to create_suite_job."""
        # No getter: the control file is passed in directly.
        self._mockDevServerGetter(get_control_file=False)

        self.dev_server.hostname = 'mox_url'
        self.dev_server.stage_artifacts(
                image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
        self.dev_server.url().AndReturn('mox_url')
        job_id = 5
        self._mockRpcUtils(job_id)
        self.mox.ReplayAll()
        self.assertEqual(
            rpc_interface.create_suite_job(name='%s/%s' % (self._NAME,
                                                           self._BUILD),
                                           board=None,
                                           builds=self._BUILDS,
                                           pool=None,
                                           control_file='CONTROL FILE'),
            job_id)


    def _get_records_for_sending_to_master(self):
        """Return ([job_record], [hqe_record]) as a shard would upload them.

        Both records use id 1. Callers mutate them to build the scenario
        under test.
        """
        return [{'control_file': 'foo',
                 'control_type': 1,
                 'created_on': datetime.datetime(2014, 8, 21),
                 'drone_set': None,
                 'email_list': '',
                 'max_runtime_hrs': 72,
                 'max_runtime_mins': 1440,
                 'name': 'dummy',
                 'owner': 'autotest_system',
                 'parse_failed_repair': True,
                 'priority': 40,
                 'reboot_after': 0,
                 'reboot_before': 1,
                 'run_reset': True,
                 'run_verify': False,
                 'synch_count': 0,
                 'test_retry': 10,
                 'timeout': 24,
                 'timeout_mins': 1440,
                 'id': 1
                 }], [{
                    'aborted': False,
                    'active': False,
                    'complete': False,
                    'deleted': False,
                    'execution_subdir': '',
                    'finished_on': None,
                    'started_on': None,
                    'status': 'Queued',
                    'id': 1
                }]


    def _send_records_to_master_helper(
        self, jobs, hqes, shard_hostname='host1',
        exception_to_throw=error.UnallowedRecordsSentToMaster, aborted=False):
        """Create a hostless job on shard 'host1', then heartbeat records.

        The job is always assigned to a shard named 'host1'. The heartbeat is
        sent as shard_hostname, so passing a different name simulates an
        upload from another shard.

        @param jobs: job records to upload in the heartbeat.
        @param hqes: HQE records to upload in the heartbeat.
        @param shard_hostname: hostname of the shard sending the heartbeat.
        @param exception_to_throw: exception the heartbeat is expected to
                                   raise, or a false value if it should
                                   succeed.
        @param aborted: if True, mark the job's HQEs aborted and unassign the
                        job from its shard before the heartbeat.
        """
        job_id = rpc_interface.create_job(
                name='dummy',
                priority=self._PRIORITY,
                control_file='foo',
                control_type=SERVER,
                test_retry=10, hostless=True)
        job = models.Job.objects.get(pk=job_id)
        shard = models.Shard.objects.create(hostname='host1')
        job.shard = shard
        job.save()

        if aborted:
            job.hostqueueentry_set.update(aborted=True)
            job.shard = None
            job.save()

        if not exception_to_throw:
            self._do_heartbeat_and_assert_response(
                shard_hostname=shard_hostname,
                upload_jobs=jobs, upload_hqes=hqes)
        else:
            self.assertRaises(
                exception_to_throw,
                self._do_heartbeat_and_assert_response,
                shard_hostname=shard_hostname,
                upload_jobs=jobs, upload_hqes=hqes)


    def testSendingRecordsToMaster(self):
        """Send records to the master and ensure they are persisted."""
        jobs, hqes = self._get_records_for_sending_to_master()
        hqes[0]['status'] = 'Completed'
        self._send_records_to_master_helper(
            jobs=jobs, hqes=hqes, exception_to_throw=None)

        # Check the entry was actually written to db
        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
                         'Completed')


    def testSendingRecordsToMasterAbortedOnMaster(self):
        """Records for a job aborted and unsharded on the master persist."""
        jobs, hqes = self._get_records_for_sending_to_master()
        hqes[0]['status'] = 'Completed'
        self._send_records_to_master_helper(
            jobs=jobs, hqes=hqes, exception_to_throw=None, aborted=True)

        # Check the entry was actually written to db
        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
                         'Completed')


    def testSendingRecordsToMasterJobAssignedToDifferentShard(self):
        """Ensure records belonging to different shard are silently rejected."""
        shard1 = models.Shard.objects.create(hostname='shard1')
        shard2 = models.Shard.objects.create(hostname='shard2')
        job1 = self._create_job(shard=shard1, control_file='foo1')
        job2 = self._create_job(shard=shard2, control_file='foo2')
        job1_id = job1.id
        job2_id = job2.id
        hqe1 = models.HostQueueEntry.objects.create(job=job1)
        hqe2 = models.HostQueueEntry.objects.create(job=job2)
        hqe1_id = hqe1.id
        hqe2_id = hqe2.id
        job1_record = job1.serialize(include_dependencies=False)
        job2_record = job2.serialize(include_dependencies=False)
        hqe1_record = hqe1.serialize(include_dependencies=False)
        hqe2_record = hqe2.serialize(include_dependencies=False)

        # Prepare a bogus job record update from the wrong shard. The update
        # should not throw an exception. Non-bogus jobs in the same update
        # should happily update.
        job1_record.update({'control_file': 'bar1'})
        job2_record.update({'control_file': 'bar2'})
        hqe1_record.update({'status': 'Aborted'})
        hqe2_record.update({'status': 'Aborted'})
        self._do_heartbeat_and_assert_response(
            shard_hostname='shard2', upload_jobs=[job1_record, job2_record],
            upload_hqes=[hqe1_record, hqe2_record])

        # Job and HQE record for wrong job should not be modified, because the
        # rpc came from the wrong shard. Job and HQE record for valid job are
        # modified.
        self.assertEqual(models.Job.objects.get(id=job1_id).control_file,
                         'foo1')
        self.assertEqual(models.Job.objects.get(id=job2_id).control_file,
                         'bar2')
        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe1_id).status,
                         '')
        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe2_id).status,
                         'Aborted')


    def testSendingRecordsToMasterNotExistingJob(self):
        """Ensure update for non existing job gets rejected."""
        jobs, hqes = self._get_records_for_sending_to_master()
        jobs[0]['id'] = 3

        # Default exception_to_throw: UnallowedRecordsSentToMaster.
        self._send_records_to_master_helper(
            jobs=jobs, hqes=hqes)


    def _createShardAndHostWithLabel(self, shard_hostname='shard1',
                                     host_hostname='host1',
                                     label_name='board:lumpy'):
        """Create a label, host, shard, and assign host to shard.

        An existing label named label_name is reused if creating it fails.

        @return: tuple (shard, host, label).
        """
        try:
            label = models.Label.objects.create(name=label_name)
        except Exception:
            # Presumably the label already exists; fall back to looking it
            # up. Narrowed from a bare except so that KeyboardInterrupt and
            # SystemExit are not swallowed.
            label = models.Label.smart_get(label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(hostname=host_hostname, leased=False,
                                          shard=shard)
        host.labels.add(label)

        return shard, host, label


    def testShardLabelRemovalInvalid(self):
        """Ensure you cannot remove the wrong label from shard."""
        shard1, _, _ = self._createShardAndHostWithLabel()
        stumpy_label = models.Label.objects.create(
                name='board:stumpy', platform=True)
        with self.assertRaises(error.RPCException):
            rpc_interface.remove_board_from_shard(
                    shard1.hostname, stumpy_label.name)


    def testShardHeartbeatLabelRemoval(self):
        """Ensure label removal from shard works."""
        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()

        self.assertEqual(host1.shard, shard1)
        self.assertItemsEqual(shard1.labels.all(), [lumpy_label])
        rpc_interface.remove_board_from_shard(
                shard1.hostname, lumpy_label.name)
        # Reload from the DB to observe the RPC's changes.
        host1 = models.Host.smart_get(host1.id)
        shard1 = models.Shard.smart_get(shard1.id)
        self.assertEqual(host1.shard, None)
        self.assertItemsEqual(shard1.labels.all(), [])


    def testCreateListShard(self):
        """Retrieve a list of all shards."""
        lumpy_label = models.Label.objects.create(name='board:lumpy',
                                                  platform=True)
        stumpy_label = models.Label.objects.create(name='board:stumpy',
                                                  platform=True)
        # 'board:peppy' must exist for the add_shard call below that uses it.
        models.Label.objects.create(name='board:peppy', platform=True)

        shard_id = rpc_interface.add_shard(
            hostname='host1', labels='board:lumpy,board:stumpy')
        self.assertRaises(error.RPCException,
                          rpc_interface.add_shard,
                          hostname='host1', labels='board:lumpy,board:stumpy')
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.add_shard,
                          hostname='host1', labels='board:peppy')
        shard = models.Shard.objects.get(pk=shard_id)
        self.assertEqual(shard.hostname, 'host1')
        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))

        self.assertEqual(rpc_interface.get_shards(),
                         [{'labels': ['board:lumpy','board:stumpy'],
                           'hostname': 'host1',
                           'id': shard_id}])


    def testAddBoardsToShard(self):
        """Add boards to a given shard."""
        _, _, lumpy_label = self._createShardAndHostWithLabel()
        stumpy_label = models.Label.objects.create(name='board:stumpy',
                                                   platform=True)
        shard_id = rpc_interface.add_board_to_shard(
            hostname='shard1', labels='board:stumpy')
        # Test whether raise exception when board label does not exist.
        self.assertRaises(models.Label.DoesNotExist,
                          rpc_interface.add_board_to_shard,
                          hostname='shard1', labels='board:test')
        # Test whether raise exception when board already sharded.
        self.assertRaises(error.RPCException,
                          rpc_interface.add_board_to_shard,
                          hostname='shard1', labels='board:lumpy')
        shard = models.Shard.objects.get(pk=shard_id)
        self.assertEqual(shard.hostname, 'shard1')
        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))

        self.assertEqual(rpc_interface.get_shards(),
                         [{'labels': ['board:lumpy','board:stumpy'],
                           'hostname': 'shard1',
                           'id': shard_id}])


    def testShardHeartbeatFetchHostlessJob(self):
        """Run the shared hostless-job fetch check against a sharded host."""
        _, host1, _ = self._createShardAndHostWithLabel()
        self._testShardHeartbeatFetchHostlessJobHelper(host1)


    def testShardHeartbeatIncorrectHosts(self):
        """Run the shared incorrect-hosts heartbeat check."""
        _, host1, _ = self._createShardAndHostWithLabel()
        self._testShardHeartbeatIncorrectHostsHelper(host1)


    def testShardHeartbeatLabelRemovalRace(self):
        """Run the shared label-removal race check."""
        shard1, host1, label1 = self._createShardAndHostWithLabel()
        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)


    def testShardRetrieveJobs(self):
        """Run the shared job-retrieval check with two distinct shards."""
        shard1, host1, label1 = self._createShardAndHostWithLabel()
        shard2, host2, label2 = self._createShardAndHostWithLabel(
                'shard2', 'host2', 'board:grumpy')
        self._testShardRetrieveJobsHelper(shard1, host1, label1,
                                          shard2, host2, label2)


    def testResendJobsAfterFailedHeartbeat(self):
        """Run the shared resend-jobs-after-failed-heartbeat check."""
        shard1, host1, label1 = self._createShardAndHostWithLabel()
        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)


    def testResendHostsAfterFailedHeartbeat(self):
        """Run the shared resend-hosts-after-failed-heartbeat check."""
        _, host1, _ = self._createShardAndHostWithLabel()
        self._testResendHostsAfterFailedHeartbeatHelper(host1)
1877
if __name__ == '__main__':
    # Run every test case defined in this module when executed directly.
    unittest.main(module='__main__')
1880