1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import datetime
6import mox
7import time
8import unittest
9
10import common
11
12from autotest_lib.frontend import setup_django_environment
13from autotest_lib.frontend.afe import frontend_test_utils
14from autotest_lib.frontend.afe import models
15from autotest_lib.frontend.afe import model_logic
16from autotest_lib.client.common_lib import error
17from autotest_lib.client.common_lib import global_config
18from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
19from autotest_lib.scheduler.shard import shard_client
20from django.core.exceptions import MultipleObjectsReturned
21
22
23class ShardClientTest(mox.MoxTestBase,
24                      frontend_test_utils.FrontendTestMixin):
25    """Unit tests for functions in shard_client.py"""
26
27
28    GLOBAL_AFE_HOSTNAME = 'foo_autotest'
29
30
31    def setUp(self):
32        super(ShardClientTest, self).setUp()
33
34        global_config.global_config.override_config_value(
35                'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)
36
37        self._frontend_common_setup(fill_data=False)
38
39
40    def tearDown(self):
41        self.mox.UnsetStubs()
42
43
44    def setup_mocks(self):
45        self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
46        self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
47                                                 delay_sec=mox.IgnoreArg(),
48                                                 timeout_min=mox.IgnoreArg())
49
50    def setup_global_config(self):
51        global_config.global_config.override_config_value(
52                'SHARD', 'is_slave_shard', 'True')
53        global_config.global_config.override_config_value(
54                'SHARD', 'shard_hostname', 'host1')
55
56
57    def expect_heartbeat(self, shard_hostname='host1',
58                         known_job_ids=[], known_host_ids=[],
59                         known_host_statuses=[], hqes=[], jobs=[],
60                         side_effect=None, return_hosts=[], return_jobs=[],
61                         return_suite_keyvals=[], return_incorrect_hosts=[]):
62        call = self.afe.run(
63            'shard_heartbeat', shard_hostname=shard_hostname,
64            hqes=hqes, jobs=jobs,
65            known_job_ids=known_job_ids, known_host_ids=known_host_ids,
66            known_host_statuses=known_host_statuses,
67            )
68
69        if side_effect:
70            call = call.WithSideEffects(side_effect)
71
72        call.AndReturn({
73                'hosts': return_hosts,
74                'jobs': return_jobs,
75                'suite_keyvals': return_suite_keyvals,
76                'incorrect_host_ids': return_incorrect_hosts,
77            })
78
79
80    def tearDown(self):
81        self._frontend_common_teardown()
82
83        # Without this global_config will keep state over test cases
84        global_config.global_config.reset_config_values()
85
86
87    def _get_sample_serialized_host(self):
88        return {'aclgroup_set': [],
89                'dirty': True,
90                'hostattribute_set': [],
91                'hostname': u'host1',
92                u'id': 2,
93                'invalid': False,
94                'labels': [],
95                'leased': True,
96                'lock_time': None,
97                'locked': False,
98                'protection': 0,
99                'shard': None,
100                'status': u'Ready'}
101
102
103    def _get_sample_serialized_job(self):
104        return {'control_file': u'foo',
105                'control_type': 2,
106                'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
107                'dependency_labels': [{u'id': 1,
108                                       'invalid': False,
109                                       'kernel_config': u'',
110                                       'name': u'board:lumpy',
111                                       'only_if_needed': False,
112                                       'platform': False}],
113                'email_list': u'',
114                'hostqueueentry_set': [{'aborted': False,
115                                        'active': False,
116                                        'complete': False,
117                                        'deleted': False,
118                                        'execution_subdir': u'',
119                                        'finished_on': None,
120                                        u'id': 1,
121                                        'meta_host': {u'id': 1,
122                                                      'invalid': False,
123                                                      'kernel_config': u'',
124                                                      'name': u'board:lumpy',
125                                                      'only_if_needed': False,
126                                                      'platform': False},
127                                        'started_on': None,
128                                        'status': u'Queued'}],
129                u'id': 1,
130                'jobkeyval_set': [],
131                'max_runtime_hrs': 72,
132                'max_runtime_mins': 1440,
133                'name': u'dummy',
134                'owner': u'autotest_system',
135                'parse_failed_repair': True,
136                'priority': 40,
137                'parent_job_id': 0,
138                'reboot_after': 0,
139                'reboot_before': 1,
140                'run_reset': True,
141                'run_verify': False,
142                'shard': {'hostname': u'shard1', u'id': 1},
143                'synch_count': 0,
144                'test_retry': 0,
145                'timeout': 24,
146                'timeout_mins': 1440}
147
148
149    def _get_sample_serialized_suite_keyvals(self):
150        return {'id': 1,
151                'job_id': 0,
152                'key': 'test_key',
153                'value': 'test_value'}
154
155
156    def testHeartbeat(self):
157        """Trigger heartbeat, verify RPCs and persisting of the responses."""
158        self.setup_mocks()
159
160        global_config.global_config.override_config_value(
161                'SHARD', 'shard_hostname', 'host1')
162
163        self.expect_heartbeat(
164                return_hosts=[self._get_sample_serialized_host()],
165                return_jobs=[self._get_sample_serialized_job()],
166                return_suite_keyvals=[
167                        self._get_sample_serialized_suite_keyvals()])
168
169        modified_sample_host = self._get_sample_serialized_host()
170        modified_sample_host['hostname'] = 'host2'
171
172        self.expect_heartbeat(
173                return_hosts=[modified_sample_host],
174                known_host_ids=[modified_sample_host['id']],
175                known_host_statuses=[modified_sample_host['status']],
176                known_job_ids=[1])
177
178
179        def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
180                                        known_host_ids, known_host_statuses,
181                                        known_job_ids):
182            self.assertEqual(len(jobs), 1)
183            self.assertEqual(len(hqes), 1)
184            job, hqe = jobs[0], hqes[0]
185            self.assertEqual(hqe['status'], 'Completed')
186
187
188        self.expect_heartbeat(
189                jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
190                known_host_ids=[modified_sample_host['id']],
191                known_host_statuses=[modified_sample_host['status']],
192                known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)
193
194        self.mox.ReplayAll()
195        sut = shard_client.get_shard_client()
196
197        sut.do_heartbeat()
198
199        # Check if dummy object was saved to DB
200        host = models.Host.objects.get(id=2)
201        self.assertEqual(host.hostname, 'host1')
202
203        # Check if suite keyval  was saved to DB
204        suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
205        self.assertEqual(suite_keyval.key, 'test_key')
206
207        sut.do_heartbeat()
208
209        # Ensure it wasn't overwritten
210        host = models.Host.objects.get(id=2)
211        self.assertEqual(host.hostname, 'host1')
212
213        job = models.Job.objects.all()[0]
214        job.shard = None
215        job.save()
216        hqe = job.hostqueueentry_set.all()[0]
217        hqe.status = 'Completed'
218        hqe.save()
219
220        sut.do_heartbeat()
221
222
223        self.mox.VerifyAll()
224
225
226    def testRemoveInvalidHosts(self):
227        self.setup_mocks()
228        self.setup_global_config()
229
230        host_serialized = self._get_sample_serialized_host()
231        host_id = host_serialized[u'id']
232
233        # 1st heartbeat: return a host.
234        # 2nd heartbeat: "delete" that host. Also send a spurious extra ID
235        # that isn't present to ensure shard client doesn't crash. (Note: delete
236        # operation doesn't actually delete db entry. Djanjo model ;logic
237        # instead simply marks it as invalid.
238        # 3rd heartbeat: host is no longer present in shard's request.
239
240        self.expect_heartbeat(return_hosts=[host_serialized])
241        self.expect_heartbeat(known_host_ids=[host_id],
242                              known_host_statuses=[u'Ready'],
243                              return_incorrect_hosts=[host_id, 42])
244        self.expect_heartbeat()
245
246        self.mox.ReplayAll()
247        sut = shard_client.get_shard_client()
248
249        sut.do_heartbeat()
250        host = models.Host.smart_get(host_id)
251        self.assertFalse(host.invalid)
252
253        # Host should no longer "exist" after the invalidation.
254        # Why don't we simply count the number of hosts in db? Because the host
255        # actually remains int he db, but simply has it's invalid bit set to
256        # True.
257        sut.do_heartbeat()
258        with self.assertRaises(models.Host.DoesNotExist):
259            host = models.Host.smart_get(host_id)
260
261
262        # Subsequent heartbeat no longer passes the host id as a known host.
263        sut.do_heartbeat()
264
265
266    def testFailAndRedownloadJobs(self):
267        self.setup_mocks()
268        self.setup_global_config()
269
270        job1_serialized = self._get_sample_serialized_job()
271        job2_serialized = self._get_sample_serialized_job()
272        job2_serialized['id'] = 2
273        job2_serialized['hostqueueentry_set'][0]['id'] = 2
274
275        self.expect_heartbeat(return_jobs=[job1_serialized])
276        self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
277        self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
278                                             job2_serialized['id']])
279        self.expect_heartbeat(known_job_ids=[job2_serialized['id']])
280
281        self.mox.ReplayAll()
282        sut = shard_client.get_shard_client()
283
284        original_process_heartbeat_response = sut.process_heartbeat_response
285        def failing_process_heartbeat_response(*args, **kwargs):
286            raise RuntimeError
287
288        sut.process_heartbeat_response = failing_process_heartbeat_response
289        self.assertRaises(RuntimeError, sut.do_heartbeat)
290
291        sut.process_heartbeat_response = original_process_heartbeat_response
292        sut.do_heartbeat()
293        sut.do_heartbeat()
294
295        job2 = models.Job.objects.get(pk=job1_serialized['id'])
296        job2.hostqueueentry_set.all().update(complete=True)
297
298        sut.do_heartbeat()
299
300        self.mox.VerifyAll()
301
302
303    def testFailAndRedownloadHosts(self):
304        self.setup_mocks()
305        self.setup_global_config()
306
307        host1_serialized = self._get_sample_serialized_host()
308        host2_serialized = self._get_sample_serialized_host()
309        host2_serialized['id'] = 3
310        host2_serialized['hostname'] = 'host2'
311
312        self.expect_heartbeat(return_hosts=[host1_serialized])
313        self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
314        self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
315                                              host2_serialized['id']],
316                              known_host_statuses=[host1_serialized['status'],
317                                                   host2_serialized['status']])
318
319        self.mox.ReplayAll()
320        sut = shard_client.get_shard_client()
321
322        original_process_heartbeat_response = sut.process_heartbeat_response
323        def failing_process_heartbeat_response(*args, **kwargs):
324            raise RuntimeError
325
326        sut.process_heartbeat_response = failing_process_heartbeat_response
327        self.assertRaises(RuntimeError, sut.do_heartbeat)
328
329        self.assertEqual(models.Host.objects.count(), 0)
330
331        sut.process_heartbeat_response = original_process_heartbeat_response
332        sut.do_heartbeat()
333        sut.do_heartbeat()
334
335        self.mox.VerifyAll()
336
337
338    def testHeartbeatNoShardMode(self):
339        """Ensure an exception is thrown when run on a non-shard machine."""
340        self.mox.ReplayAll()
341
342        self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
343                          shard_client.get_shard_client)
344
345        self.mox.VerifyAll()
346
347
348    def testLoop(self):
349        """Test looping over heartbeats and aborting that loop works."""
350        self.setup_mocks()
351        self.setup_global_config()
352
353        global_config.global_config.override_config_value(
354                'SHARD', 'heartbeat_pause_sec', '0.01')
355
356        self.expect_heartbeat()
357
358        sut = None
359
360        def shutdown_sut(*args, **kwargs):
361            sut.shutdown()
362
363        self.expect_heartbeat(side_effect=shutdown_sut)
364
365        self.mox.ReplayAll()
366        sut = shard_client.get_shard_client()
367        sut.loop(None)
368
369        self.mox.VerifyAll()
370
371
372    def testLoopWithDeadline(self):
373        """Test looping over heartbeats with a timeout."""
374        self.setup_mocks()
375        self.setup_global_config()
376        self.mox.StubOutWithMock(time, 'time')
377
378        global_config.global_config.override_config_value(
379                'SHARD', 'heartbeat_pause_sec', '0.01')
380        time.time().AndReturn(1516894000)
381        time.time().AndReturn(1516894000)
382        self.expect_heartbeat()
383        # Set expectation that heartbeat took 1 minute.
384        time.time().MultipleTimes().AndReturn(1516894000 + 60)
385
386        self.mox.ReplayAll()
387        sut = shard_client.get_shard_client()
388        # 36 seconds
389        sut.loop(lifetime_hours=0.01)
390        self.mox.VerifyAll()
391
392    def test_remove_incorrect_hosts(self):
393        """Test _remove_incorrect_hosts with MultipleObjectsReturned."""
394        self.setup_mocks()
395        self.setup_global_config()
396        self.mox.StubOutWithMock(model_logic.ModelWithInvalidQuerySet, 'delete')
397        call = models.Host.objects.filter(id__in=[1]).delete()
398        call.AndRaise(MultipleObjectsReturned('e'))
399
400        self.mox.ReplayAll()
401        sut = shard_client.get_shard_client()
402        sut._remove_incorrect_hosts(incorrect_host_ids=[1])
403
404        self.mox.VerifyAll()
405
406
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
409