1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import datetime
6import mox
7import unittest
8
9import common
10
11from autotest_lib.frontend import setup_django_environment
12from autotest_lib.frontend.afe import frontend_test_utils
13from autotest_lib.frontend.afe import models
14from autotest_lib.client.common_lib import error
15from autotest_lib.client.common_lib import global_config
16from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
17from autotest_lib.scheduler.shard import shard_client
18
19
class ShardClientTest(mox.MoxTestBase,
                      frontend_test_utils.FrontendTestMixin):
    """Unit tests for functions in shard_client.py"""


    GLOBAL_AFE_HOSTNAME = 'foo_autotest'


    def setUp(self):
        """Point the shard client at a fake global AFE and set up the DB."""
        super(ShardClientTest, self).setUp()

        global_config.global_config.override_config_value(
                'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)

        self._frontend_common_setup(fill_data=False)


    def setup_mocks(self):
        """Replace frontend_wrappers.RetryingAFE with a mox class mock.

        Instantiating the stubbed class here records the constructor call the
        code under test is expected to make. The resulting mock is stored in
        self.afe so expect_heartbeat() can record RPC expectations on it.
        """
        self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
        self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
                                                 delay_sec=5,
                                                 timeout_min=5)


    def setup_global_config(self):
        """Override the SHARD config so this process acts as shard 'host1'."""
        global_config.global_config.override_config_value(
                'SHARD', 'is_slave_shard', 'True')
        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')


    def expect_heartbeat(self, shard_hostname='host1',
                         known_job_ids=None, known_host_ids=None,
                         known_host_statuses=None, hqes=None, jobs=None,
                         side_effect=None, return_hosts=None, return_jobs=None,
                         return_suite_keyvals=None):
        """Record one expected 'shard_heartbeat' RPC on the mocked AFE.

        Every list argument defaults to a fresh empty list. The signature uses
        None rather than [] so no mutable default is shared between calls.

        @param shard_hostname: Hostname the shard is expected to send.
        @param known_job_ids: Job ids the shard is expected to send.
        @param known_host_ids: Host ids the shard is expected to send.
        @param known_host_statuses: Host statuses the shard is expected to
                send.
        @param hqes: Expected uploaded host queue entries (may be a mox
                comparator such as mox.IgnoreArg()).
        @param jobs: Expected uploaded jobs (may be a mox comparator).
        @param side_effect: Optional callable that mox invokes with the RPC's
                arguments when the call is replayed.
        @param return_hosts: Serialized hosts to return in the response.
        @param return_jobs: Serialized jobs to return in the response.
        @param return_suite_keyvals: Serialized suite keyvals to return in the
                response.
        """
        def _or_empty(value):
            # Build a new list per call instead of sharing a default.
            return [] if value is None else value

        call = self.afe.run(
            'shard_heartbeat', shard_hostname=shard_hostname,
            hqes=_or_empty(hqes), jobs=_or_empty(jobs),
            known_job_ids=_or_empty(known_job_ids),
            known_host_ids=_or_empty(known_host_ids),
            known_host_statuses=_or_empty(known_host_statuses),
            )

        if side_effect:
            call = call.WithSideEffects(side_effect)

        call.AndReturn({
                'hosts': _or_empty(return_hosts),
                'jobs': _or_empty(return_jobs),
                'suite_keyvals': _or_empty(return_suite_keyvals),
            })


    def tearDown(self):
        """Tear down the test DB and drop all global_config overrides."""
        try:
            self._frontend_common_teardown()
        finally:
            # Without this global_config will keep state over test cases, so
            # reset it even if the DB teardown above raised.
            global_config.global_config.reset_config_values()


    def _get_sample_serialized_host(self):
        """Return a serialized host dict as it appears in a heartbeat reply."""
        return {'aclgroup_set': [],
                'dirty': True,
                'hostattribute_set': [],
                'hostname': u'host1',
                u'id': 2,
                'invalid': False,
                'labels': [],
                'leased': True,
                'lock_time': None,
                'locked': False,
                'protection': 0,
                'shard': None,
                'status': u'Ready',
                'synch_id': None}


    def _get_sample_serialized_job(self):
        """Return a serialized job (id 1) with a single queued HQE (id 1)."""
        return {'control_file': u'foo',
                'control_type': 2,
                'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
                'dependency_labels': [{u'id': 1,
                                       'invalid': False,
                                       'kernel_config': u'',
                                       'name': u'board:lumpy',
                                       'only_if_needed': False,
                                       'platform': False}],
                'email_list': u'',
                'hostqueueentry_set': [{'aborted': False,
                                        'active': False,
                                        'complete': False,
                                        'deleted': False,
                                        'execution_subdir': u'',
                                        'finished_on': None,
                                        u'id': 1,
                                        'meta_host': {u'id': 1,
                                                      'invalid': False,
                                                      'kernel_config': u'',
                                                      'name': u'board:lumpy',
                                                      'only_if_needed': False,
                                                      'platform': False},
                                        'started_on': None,
                                        'status': u'Queued'}],
                u'id': 1,
                'jobkeyval_set': [],
                'max_runtime_hrs': 72,
                'max_runtime_mins': 1440,
                'name': u'dummy',
                'owner': u'autotest_system',
                'parse_failed_repair': True,
                'priority': 40,
                'parent_job_id': 0,
                'reboot_after': 0,
                'reboot_before': 1,
                'run_reset': True,
                'run_verify': False,
                'shard': {'hostname': u'shard1', u'id': 1},
                'synch_count': 0,
                'test_retry': 0,
                'timeout': 24,
                'timeout_mins': 1440}


    def _get_sample_serialized_suite_keyvals(self):
        """Return a serialized suite keyval attached to job_id 0."""
        return {'id': 1,
                'job_id': 0,
                'key': 'test_key',
                'value': 'test_value'}


    def testHeartbeat(self):
        """Trigger heartbeat, verify RPCs and persisting of the responses."""
        self.setup_mocks()

        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')

        # 1st heartbeat: nothing known yet; the reply carries a host, a job
        # and a suite keyval.
        self.expect_heartbeat(
                return_hosts=[self._get_sample_serialized_host()],
                return_jobs=[self._get_sample_serialized_job()],
                return_suite_keyvals=[
                        self._get_sample_serialized_suite_keyvals()])

        modified_sample_host = self._get_sample_serialized_host()
        modified_sample_host['hostname'] = 'host2'

        # 2nd heartbeat: host 2 and job 1 are now known. The reply resends
        # host 2 under a different hostname.
        self.expect_heartbeat(
                return_hosts=[modified_sample_host],
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[1])


        def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
                                        known_host_ids, known_host_statuses,
                                        known_job_ids):
            # Invoked by mox with the RPC's arguments. The unused parameters
            # are required so the keyword arguments bind.
            self.assertEqual(len(jobs), 1)
            self.assertEqual(len(hqes), 1)
            self.assertEqual(hqes[0]['status'], 'Completed')


        # 3rd heartbeat: after the job is detached from the shard and its HQE
        # completed (below), job 1 is expected to be uploaded and no longer
        # reported as known.
        self.expect_heartbeat(
                jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        sut.do_heartbeat()

        # Check if dummy object was saved to DB
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Check if suite keyval was saved to DB
        suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
        self.assertEqual(suite_keyval.key, 'test_key')

        sut.do_heartbeat()

        # Ensure it wasn't overwritten
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        job = models.Job.objects.all()[0]
        job.shard = None
        job.save()
        hqe = job.hostqueueentry_set.all()[0]
        hqe.status = 'Completed'
        hqe.save()

        sut.do_heartbeat()


        self.mox.VerifyAll()


    def testFailAndRedownloadJobs(self):
        """Jobs whose heartbeat processing failed are downloaded again.

        The first heartbeat's reply processing raises, so the next heartbeat
        is expected to report no known jobs and to receive job 1 again
        together with job 2.
        """
        self.setup_mocks()
        self.setup_global_config()

        job1_serialized = self._get_sample_serialized_job()
        job2_serialized = self._get_sample_serialized_job()
        job2_serialized['id'] = 2
        job2_serialized['hostqueueentry_set'][0]['id'] = 2

        self.expect_heartbeat(return_jobs=[job1_serialized])
        self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
        self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
                                             job2_serialized['id']])
        self.expect_heartbeat(known_job_ids=[job2_serialized['id']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        original_process_heartbeat_response = sut.process_heartbeat_response
        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        self.assertRaises(RuntimeError, sut.do_heartbeat)

        sut.process_heartbeat_response = original_process_heartbeat_response
        sut.do_heartbeat()
        sut.do_heartbeat()

        # Complete job 1's HQEs; the last expected heartbeat then reports only
        # job 2 as known.
        job1 = models.Job.objects.get(pk=job1_serialized['id'])
        job1.hostqueueentry_set.all().update(complete=True)

        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testFailAndRedownloadHosts(self):
        """Hosts whose heartbeat processing failed are downloaded again.

        When the first reply's processing raises, no host must be saved. The
        next heartbeat receives both hosts, and the one after reports both as
        known.
        """
        self.setup_mocks()
        self.setup_global_config()

        host1_serialized = self._get_sample_serialized_host()
        host2_serialized = self._get_sample_serialized_host()
        host2_serialized['id'] = 3
        host2_serialized['hostname'] = 'host2'

        self.expect_heartbeat(return_hosts=[host1_serialized])
        self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
        self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
                                              host2_serialized['id']],
                              known_host_statuses=[host1_serialized['status'],
                                                   host2_serialized['status']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        original_process_heartbeat_response = sut.process_heartbeat_response
        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        self.assertRaises(RuntimeError, sut.do_heartbeat)

        self.assertEqual(models.Host.objects.count(), 0)

        sut.process_heartbeat_response = original_process_heartbeat_response
        sut.do_heartbeat()
        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testHeartbeatNoShardMode(self):
        """Ensure an exception is thrown when run on a non-shard machine."""
        self.mox.ReplayAll()

        self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
                          shard_client.get_shard_client)

        self.mox.VerifyAll()


    def testLoop(self):
        """Test looping over heartbeats and aborting that loop works."""
        self.setup_mocks()
        self.setup_global_config()

        global_config.global_config.override_config_value(
                'SHARD', 'heartbeat_pause_sec', '0.01')

        self.expect_heartbeat()

        # Bound later; the closure reads it only when mox replays the call.
        sut = None

        def shutdown_sut(*args, **kwargs):
            sut.shutdown()

        self.expect_heartbeat(side_effect=shutdown_sut)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()
        sut.loop()

        self.mox.VerifyAll()
325
if __name__ == '__main__':
    # Run every TestCase defined in this module with the stock runner.
    unittest.main(module='__main__')
328