1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import datetime
6import mox
7import unittest
8
9import common
10
11from autotest_lib.frontend import setup_django_environment
12from autotest_lib.frontend.afe import frontend_test_utils
13from autotest_lib.frontend.afe import models
14from autotest_lib.client.common_lib import error
15from autotest_lib.client.common_lib import global_config
16from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
17from autotest_lib.scheduler.shard import shard_client
18
19
class ShardClientTest(mox.MoxTestBase,
                      frontend_test_utils.FrontendTestMixin):
    """Unit tests for functions in shard_client.py"""


    GLOBAL_AFE_HOSTNAME = 'foo_autotest'


    def setUp(self):
        """Configure the global AFE hostname and set up the frontend DB."""
        super(ShardClientTest, self).setUp()

        global_config.global_config.override_config_value(
                'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)

        self._frontend_common_setup(fill_data=False)


    def setup_mocks(self):
        """Stub out RetryingAFE with mocks and record its construction.

        The object returned by the recorded constructor call is stored in
        self.afe so tests can record expected RPCs (see expect_heartbeat).
        """
        self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
        self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
                                                 delay_sec=5,
                                                 timeout_min=5)


    def setup_global_config(self):
        """Override SHARD config so this machine acts as shard 'host1'."""
        global_config.global_config.override_config_value(
                'SHARD', 'is_slave_shard', 'True')
        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')


    def expect_heartbeat(self, shard_hostname='host1',
                         known_job_ids=None, known_host_ids=None,
                         known_host_statuses=None, hqes=None, jobs=None,
                         side_effect=None, return_hosts=None, return_jobs=None,
                         return_suite_keyvals=None):
        """Record one expected 'shard_heartbeat' RPC on the mocked AFE.

        Every list-valued argument defaults to a fresh empty list. None is
        used as the default instead of [] so that no list object is shared
        between calls.

        @param shard_hostname: Expected shard_hostname argument of the RPC.
        @param known_job_ids: Expected known_job_ids argument of the RPC.
        @param known_host_ids: Expected known_host_ids argument of the RPC.
        @param known_host_statuses: Expected known_host_statuses argument.
        @param hqes: Expected hqes argument (may be a mox comparator).
        @param jobs: Expected jobs argument (may be a mox comparator).
        @param side_effect: Optional callable invoked with the RPC's
                arguments when the call is replayed.
        @param return_hosts: 'hosts' entry of the canned RPC response.
        @param return_jobs: 'jobs' entry of the canned RPC response.
        @param return_suite_keyvals: 'suite_keyvals' entry of the canned
                RPC response.
        """
        def _list_or_empty(value):
            # Give each call its own empty list instead of a shared default.
            return [] if value is None else value

        call = self.afe.run(
            'shard_heartbeat', shard_hostname=shard_hostname,
            hqes=_list_or_empty(hqes), jobs=_list_or_empty(jobs),
            known_job_ids=_list_or_empty(known_job_ids),
            known_host_ids=_list_or_empty(known_host_ids),
            known_host_statuses=_list_or_empty(known_host_statuses),
            )

        if side_effect:
            call = call.WithSideEffects(side_effect)

        call.AndReturn({
                'hosts': _list_or_empty(return_hosts),
                'jobs': _list_or_empty(return_jobs),
                'suite_keyvals': _list_or_empty(return_suite_keyvals),
            })


    def tearDown(self):
        """Tear down the frontend environment and reset global config."""
        try:
            self._frontend_common_teardown()
        finally:
            # Without this global_config will keep state over test cases.
            # Done in `finally` so a teardown failure cannot leak overrides.
            global_config.global_config.reset_config_values()


    def _get_sample_serialized_host(self):
        """Return a sample serialized host dict (id 2, hostname 'host1')."""
        return {'aclgroup_set': [],
                'dirty': True,
                'hostattribute_set': [],
                'hostname': u'host1',
                u'id': 2,
                'invalid': False,
                'labels': [],
                'leased': True,
                'lock_time': None,
                'locked': False,
                'protection': 0,
                'shard': None,
                'status': u'Ready'}


    def _get_sample_serialized_job(self):
        """Return a sample serialized job dict (id 1) with one queued HQE."""
        return {'control_file': u'foo',
                'control_type': 2,
                'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
                'dependency_labels': [{u'id': 1,
                                       'invalid': False,
                                       'kernel_config': u'',
                                       'name': u'board:lumpy',
                                       'only_if_needed': False,
                                       'platform': False}],
                'email_list': u'',
                'hostqueueentry_set': [{'aborted': False,
                                        'active': False,
                                        'complete': False,
                                        'deleted': False,
                                        'execution_subdir': u'',
                                        'finished_on': None,
                                        u'id': 1,
                                        'meta_host': {u'id': 1,
                                                      'invalid': False,
                                                      'kernel_config': u'',
                                                      'name': u'board:lumpy',
                                                      'only_if_needed': False,
                                                      'platform': False},
                                        'started_on': None,
                                        'status': u'Queued'}],
                u'id': 1,
                'jobkeyval_set': [],
                'max_runtime_hrs': 72,
                'max_runtime_mins': 1440,
                'name': u'dummy',
                'owner': u'autotest_system',
                'parse_failed_repair': True,
                'priority': 40,
                'parent_job_id': 0,
                'reboot_after': 0,
                'reboot_before': 1,
                'run_reset': True,
                'run_verify': False,
                'shard': {'hostname': u'shard1', u'id': 1},
                'synch_count': 0,
                'test_retry': 0,
                'timeout': 24,
                'timeout_mins': 1440}


    def _get_sample_serialized_suite_keyvals(self):
        """Return a sample serialized suite keyval dict for job_id 0."""
        return {'id': 1,
                'job_id': 0,
                'key': 'test_key',
                'value': 'test_value'}


    def testHeartbeat(self):
        """Trigger heartbeat, verify RPCs and persisting of the responses."""
        self.setup_mocks()

        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')

        # 1st heartbeat: the shard knows nothing; it receives a host, a job
        # and a suite keyval.
        self.expect_heartbeat(
                return_hosts=[self._get_sample_serialized_host()],
                return_jobs=[self._get_sample_serialized_job()],
                return_suite_keyvals=[
                        self._get_sample_serialized_suite_keyvals()])

        modified_sample_host = self._get_sample_serialized_host()
        modified_sample_host['hostname'] = 'host2'

        # 2nd heartbeat: the shard reports what it stored. The response
        # carries a modified copy of the host, which must not overwrite it.
        self.expect_heartbeat(
                return_hosts=[modified_sample_host],
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[1])


        def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
                                        known_host_ids, known_host_statuses,
                                        known_job_ids):
            # Side effect of the 3rd heartbeat: exactly one job and its
            # completed HQE must be uploaded.
            self.assertEqual(len(jobs), 1)
            self.assertEqual(len(hqes), 1)
            hqe = hqes[0]
            self.assertEqual(hqe['status'], 'Completed')


        # 3rd heartbeat: after the job is detached from the shard and its
        # HQE completed, it is no longer reported as known.
        self.expect_heartbeat(
                jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        sut.do_heartbeat()

        # Check if dummy object was saved to DB
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Check if suite keyval was saved to DB
        suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
        self.assertEqual(suite_keyval.key, 'test_key')

        sut.do_heartbeat()

        # Ensure it wasn't overwritten
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        job = models.Job.objects.all()[0]
        job.shard = None
        job.save()
        hqe = job.hostqueueentry_set.all()[0]
        hqe.status = 'Completed'
        hqe.save()

        sut.do_heartbeat()


        self.mox.VerifyAll()


    def testFailAndRedownloadJobs(self):
        """Jobs from a failed heartbeat are re-sent; completed jobs drop out.

        A heartbeat whose response processing raises must not cause the
        shard to report those jobs as known on the next heartbeat.
        """
        self.setup_mocks()
        self.setup_global_config()

        job1_serialized = self._get_sample_serialized_job()
        job2_serialized = self._get_sample_serialized_job()
        job2_serialized['id'] = 2
        job2_serialized['hostqueueentry_set'][0]['id'] = 2

        # 1st: job1 is returned, but processing will be made to fail.
        self.expect_heartbeat(return_jobs=[job1_serialized])
        # 2nd: no known jobs yet, so both jobs are (re)downloaded.
        self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
        # 3rd: both jobs are now known.
        self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
                                             job2_serialized['id']])
        # 4th: job1's HQEs were completed, so only job2 is still known.
        self.expect_heartbeat(known_job_ids=[job2_serialized['id']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        original_process_heartbeat_response = sut.process_heartbeat_response
        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        self.assertRaises(RuntimeError, sut.do_heartbeat)

        sut.process_heartbeat_response = original_process_heartbeat_response
        sut.do_heartbeat()
        sut.do_heartbeat()

        job1 = models.Job.objects.get(pk=job1_serialized['id'])
        job1.hostqueueentry_set.all().update(complete=True)

        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testFailAndRedownloadHosts(self):
        """Hosts from a failed heartbeat are not saved and are re-sent."""
        self.setup_mocks()
        self.setup_global_config()

        host1_serialized = self._get_sample_serialized_host()
        host2_serialized = self._get_sample_serialized_host()
        host2_serialized['id'] = 3
        host2_serialized['hostname'] = 'host2'

        # 1st: host1 is returned, but processing will be made to fail.
        self.expect_heartbeat(return_hosts=[host1_serialized])
        # 2nd: no known hosts yet, so both hosts are (re)downloaded.
        self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
        # 3rd: both hosts and their statuses are now known.
        self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
                                              host2_serialized['id']],
                              known_host_statuses=[host1_serialized['status'],
                                                   host2_serialized['status']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        original_process_heartbeat_response = sut.process_heartbeat_response
        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        self.assertRaises(RuntimeError, sut.do_heartbeat)

        # Nothing from the failed heartbeat may have been persisted.
        self.assertEqual(models.Host.objects.count(), 0)

        sut.process_heartbeat_response = original_process_heartbeat_response
        sut.do_heartbeat()
        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testHeartbeatNoShardMode(self):
        """Ensure an exception is thrown when run on a non-shard machine."""
        self.mox.ReplayAll()

        self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
                          shard_client.get_shard_client)

        self.mox.VerifyAll()


    def testLoop(self):
        """Test looping over heartbeats and aborting that loop works."""
        self.setup_mocks()
        self.setup_global_config()

        global_config.global_config.override_config_value(
                'SHARD', 'heartbeat_pause_sec', '0.01')

        self.expect_heartbeat()

        # Bound after ReplayAll; the closure below reads it at call time.
        sut = None

        def shutdown_sut(*args, **kwargs):
            sut.shutdown()

        # The second heartbeat shuts the client down, ending loop().
        self.expect_heartbeat(side_effect=shutdown_sut)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()
        sut.loop()

        self.mox.VerifyAll()
323
324
if __name__ == '__main__':
    # Run the test cases in this module when executed directly as a script.
    unittest.main()
327