#!/usr/bin/python
#pylint: disable-msg=C0111

"""Unit tests for the classes in autotest_lib.scheduler.scheduler_models."""

import datetime
import functools
import unittest

import common
from autotest_lib.frontend import setup_django_environment
from autotest_lib.frontend.afe import frontend_test_utils
from autotest_lib.client.common_lib import host_queue_entry_states
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.database import database_connection
from autotest_lib.frontend.afe import models, model_attributes
from autotest_lib.scheduler import monitor_db
from autotest_lib.scheduler import scheduler_lib
from autotest_lib.scheduler import scheduler_models

_DEBUG = False


class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Base class for tests that run scheduler_models against a test DB.

    setUp() performs the common frontend setup and then stubs
    scheduler_models._db with a TranslatingDatabase test connection
    (connected with db_type='django').
    """
    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement on the test database connection."""
        self._database.execute(sql)


    def _set_monitor_stubs(self):
        """Create a fresh test DB connection and install it in
        scheduler_models."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        self.god.stub_with(scheduler_models, '_db', self._database)


    def setUp(self):
        self._frontend_common_setup()
        self._set_monitor_stubs()


    def tearDown(self):
        self._database.disconnect()
        self._frontend_common_teardown()


    def _update_hqe(self, set, where=''):
        """Run an UPDATE statement against afe_host_queue_entries.

        @param set: SQL assignment list placed after the SET keyword. (The
                name shadows the builtin; it is kept for keyword callers.)
        @param where: Optional SQL condition placed after a WHERE keyword;
                when empty, no WHERE clause is added.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)


class DelayedCallTaskTest(unittest.TestCase):
    """Tests for scheduler_models.DelayedCallTask."""

    def setUp(self):
        self.god = mock.mock_god()


    def tearDown(self):
        self.god.unstub_all()


    def test_delayed_call(self):
        test_time = self.god.create_mock_function('time')
        test_time.expect_call().and_return(33)
        test_time.expect_call().and_return(34.01)
        test_time.expect_call().and_return(34.99)
        test_time.expect_call().and_return(35.01)
        def test_callback():
            test_callback.calls += 1
        test_callback.calls = 0
        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=2, callback=test_callback,
                now_func=test_time)  # time 33
        self.assertEqual(35, delay_task.end_time)
        delay_task.poll()  # activates the task and polls it once, time 34.01
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 34.99
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 35.01
        self.assertEqual(1, test_callback.calls)
        self.assertTrue(delay_task.is_done())
        self.assertTrue(delay_task.success)
        self.assertFalse(delay_task.aborted)
        self.god.check_playback()


    def test_delayed_call_abort(self):
        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=987654, callback=lambda : None)
        delay_task.abort()
        self.assertTrue(delay_task.aborted)
        self.assertTrue(delay_task.is_done())
        self.assertFalse(delay_task.success)
        self.god.check_playback()


class DBObjectTest(BaseSchedulerModelsTest):
    """Tests for the scheduler_models.DBObject base class behaviour."""

    def test_compare_fields_in_row(self):
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))


    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        # Plain decimal day literals: the legacy octal form "07" is a
        # SyntaxError on Python 3 and has the same value (7) on Python 2.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)
        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')
        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))


    def test_always_query(self):
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertTrue(host_a is host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertTrue(host_a is host_c, 'Cached instance not returned')


    def test_delete(self):
        host = scheduler_models.Host(id=3)
        host.delete()
        # Looking the deleted row up must fail with and without always_query.
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)


    def test_save(self):
        # Dummy Job to avoid creating a one in the HostQueueEntry __init__.
        class MockJob(object):
            def __init__(self, id):
                pass
            def tag(self):
                return 'MockJob'
        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)


class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host."""

    def test_cmp_for_sort(self):
        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(host10, host10))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host10, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host010, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host010))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, host1))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host3))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(alice, host3))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, alice))
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(alice, alice))

        # cmp_to_key gives the same (stable) ordering as sort(cmp=...) but
        # also works on Python 3, where the cmp argument was removed.
        sort_key = functools.cmp_to_key(scheduler_models.Host.cmp_for_sort)
        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        hosts.reverse()
        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])


class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a job with exactly one HQE and return that HQE.

        @param dependency_labels: Labels added to the job's dependencies.
        @param create_job_kwargs: Passed through to self._create_job().
        @return The single scheduler_models.HostQueueEntry of the new job.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]


    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert that hqe.get_labels() returns exactly expected_labels."""
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)


    def test_get_labels_empty(self):
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)


    def test_get_labels_metahost(self):
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])


    def test_get_labels_dependencies(self):
        hqe = self._create_hqe(dependency_labels=(self.label3,),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3'])


    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
                               to test abort method.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.BaseDispatcher,
                                                'BaseDispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher


    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)


    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)


    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                self.assertIsNotNone(hqe.finished_on)
                self.assertIsNone(hqe.job.shard_id)
            else:
                self.assertIsNone(hqe.finished_on)
                self.assertEqual(hqe.job.shard_id, 3)


class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job, mainly pre-job special tasks."""

    def setUp(self):
        super(JobTest, self).setUp()
        # SpecialTask objects created through the stubbed manager method are
        # collected here so tests can inspect what was scheduled.
        self._tasks = []

        def _mock_create(**kwargs):
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)
            # Django's create() returns the new object; keep that contract.
            return task
        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)


    def _test_pre_job_tasks_helper(self,
                            reboot_before=model_attributes.RebootBefore.ALWAYS):
        """
        Calls HQE._do_schedule_pre_job_tasks() and returns the created special
        task
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks


    def test_job_request_abort(self):
        django_job = self._create_job(hosts=[5, 6])
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        # Guard against the loop below passing vacuously.
        self.assertTrue(django_hqes, 'No HQEs found for the aborted job')
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)


    def _check_special_tasks(self, tasks, task_types):
        """Assert tasks match task_types, a list of (task, queue_entry_id).

        A falsy queue_entry_id skips the queue-entry check for that task.
        Every task is expected to be for host 1.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)


    def test_run_asynchronous(self):
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_verify(self):
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_do_not_reset(self):
        # NOTE(review): this test was previously shadowed by a later method
        # with the same name and therefore never ran. Confirm that its
        # expectation still holds now that it executes.
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])


    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_reboot_before_always(self):
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
            ])


    def _test_reboot_before_if_dirty_helper(self):
        """Schedule pre-job tasks for an IF_DIRTY job on host 1 and expect a
        single RESET task."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)


    def test_reboot_before_if_dirty(self):
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()


    def test_reboot_before_not_dirty(self):
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()


if __name__ == '__main__':
    unittest.main()