# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit tests for autotest_lib.server.hosts.host_info."""

import cStringIO
import inspect
import json
import unittest

# NOTE(review): `common` is never referenced below; presumably it is imported
# only for its side effect of making `autotest_lib` importable.
import common
from autotest_lib.server.hosts import host_info


class HostInfoTest(unittest.TestCase):
    """Tests the non-trivial attributes of HostInfo."""

    def setUp(self):
        self.info = host_info.HostInfo()


    def test_info_comparison_to_wrong_type(self):
        """Comparing HostInfo to a different type always returns False."""
        self.assertNotEqual(host_info.HostInfo(), 42)
        self.assertNotEqual(host_info.HostInfo(), None)
        # assertNotEqual only exercises `!=`. Python 2's data model does not
        # derive __eq__/__ne__ from each other, so `==` is checked explicitly.
        # `== None` is deliberate: it exercises HostInfo.__eq__, not identity.
        self.assertFalse(host_info.HostInfo() == 42)
        self.assertFalse(host_info.HostInfo() == None)


    def test_empty_infos_are_equal(self):
        """Tests that empty HostInfo objects are considered equal."""
        self.assertEqual(host_info.HostInfo(), host_info.HostInfo())
        # assertEqual only exercises `==`; check `!=` independently.
        self.assertFalse(host_info.HostInfo() != host_info.HostInfo())


    def test_non_trivial_infos_are_equal(self):
        """Tests that the most complicated infos are correctly stated equal."""
        info1 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        info2 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        self.assertEqual(info1, info2)
        # assertEqual only exercises `==`; check `!=` independently.
        self.assertFalse(info1 != info2)


    def test_non_equal_infos(self):
        """Tests that HostInfo objects with different information are unequal."""
        info1 = host_info.HostInfo(labels=['label'])
        info2 = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(info1, info2)
        # assertNotEqual only exercises `!=`; check `==` independently.
        self.assertFalse(info1 == info2)


    def test_build_needs_prefix(self):
        """The build prefix is of the form '<type>-version:'."""
        # None of these labels carries the trailing ':' of a version prefix.
        self.info.labels = ['cros-version', 'fwrw-version', 'fwro-version']
        self.assertIsNone(self.info.build)


    def test_build_prefix_must_be_anchored(self):
        """Ensure that build ignores prefixes occurring mid-string."""
        self.info.labels = ['not-at-start-cros-version:cros1']
        self.assertIsNone(self.info.build)


    def test_build_ignores_firmware(self):
        """build attribute should ignore firmware versions."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        self.assertIsNone(self.info.build)


    def test_build_returns_first_match(self):
        """When multiple labels match, first one should be used as build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        self.assertEqual(self.info.build, 'cros1')


    def test_build_prefer_cros_over_others(self):
        """When multiple versions are available, prefer cros."""
        # The preference must hold regardless of label order.
        self.info.labels = ['cheets-version:ab1', 'cros-version:cros1']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['cros-version:cros1', 'cheets-version:ab1']
        self.assertEqual(self.info.build, 'cros1')


    def test_os_no_match(self):
        """Use proper prefix to search for os information."""
        self.info.labels = ['something_else', 'cros-version:hana',
                            'os_without_colon']
        self.assertEqual(self.info.os, '')


    def test_os_returns_first_match(self):
        """Return the first matching os label."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        self.assertEqual(self.info.os, 'linux')


    def test_board_no_match(self):
        """Use proper prefix to search for board information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon']
        self.assertEqual(self.info.board, '')


    def test_board_returns_first_match(self):
        """Return the first matching board label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        self.assertEqual(self.info.board, 'walk')


    def test_pools_no_match(self):
        """Use proper prefix to search for pool information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon', 'board:my_board']
        self.assertEqual(self.info.pools, set())


    def test_pools_returns_all_matches(self):
        """Return all matching pool labels."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
                            'pool:first_pool', 'pool:second_pool']
        self.assertEqual(self.info.pools, {'second_pool', 'first_pool'})


    def test_str(self):
        """Sanity checks the __str__ implementation."""
        info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
        self.assertEqual(str(info),
                         "HostInfo[Labels: ['a'], Attributes: {'b': 2}]")


    def test_clear_version_labels_no_labels(self):
        """When no version labels exist, do nothing for clear_version_labels."""
        original_labels = ['board:something', 'os:something_else',
                           'pool:mypool', 'cheets-version-corrupted:blah',
                           'cros-version']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, original_labels)


    def test_clear_all_version_labels(self):
        """Clear each recognized type of version label."""
        original_labels = ['extra_label', 'cros-version:cr1',
                           'cheets-version:ab1']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_clear_all_version_label_prefixes(self):
        """Clear each recognized type of version label with empty value."""
        original_labels = ['extra_label', 'cros-version:', 'cheets-version:']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_set_version_labels_updates_in_place(self):
        """Update version label in place if prefix already exists."""
        self.info.labels = ['extra', 'cros-version:X', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cros-version:Z',
                                                'cheets-version:Y'])


    def test_set_version_labels_appends(self):
        """Append a new version label if the prefix doesn't exist."""
        self.info.labels = ['extra', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cheets-version:Y',
                                                'cros-version:Z'])


class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""

    def setUp(self):
        self.store = host_info.InMemoryHostInfoStore()


    def _verify_host_info_data(self, info, labels, attributes):
        """Verifies the data in the given HostInfo.

        @param info: The HostInfo object to check.
        @param labels: The exact list of labels expected on |info|.
        @param attributes: The exact dict of attributes expected on |info|.
        """
        self.assertListEqual(info.labels, labels)
        self.assertDictEqual(info.attributes, attributes)


    def test_first_get_refreshes_cache(self):
        """Test that the first call to get gets the data from store."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_repeated_get_returns_from_cache(self):
        """Tests that repeated calls to get do not refresh cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        # Changing the backing data must not be visible through the cache.
        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_get_uncached_always_refreshes_cache(self):
        """Tests that get(force_refresh=True) always refreshes the cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1', 'label2'], {})


    def test_commit(self):
        """Test that commit sends data to store."""
        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(info)
        self._verify_host_info_data(self.store.info, ['label1'],
                                    {'attrib1': 'val1'})


    def test_commit_then_get(self):
        """Test a commit-get roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get_uncached(self):
        """Test a commit / get(force_refresh=True) roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_deepcopies_data(self):
        """Once committed, changes to HostInfo don't corrupt the store."""
        info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
        self.store.commit(info)
        # Mutate both a top-level container and a nested one.
        info.labels.append('label2')
        info.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_get_returns_deepcopy(self):
        """The cached object is protected from |get| caller modifications."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        # Mutate both a top-level container and a nested one.
        got.labels.append('label2')
        got.attributes['attrib1']['key1'] = 'data2'
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_str(self):
        """Sanity tests __str__ implementation."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        self.assertEqual(str(self.store),
                         'InMemoryHostInfoStore[%s]' % self.store.info)


class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A test store whose refresh / commit raise StoreError by default.

    Tests may set |refresh_raises| / |commit_raises| to False to let the
    corresponding operation succeed.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.refresh_raises = True
        self.commit_raises = True


    def _refresh_impl(self):
        """Raise StoreError, or return an empty HostInfo if allowed."""
        if self.refresh_raises:
            raise host_info.StoreError('no can do')
        return host_info.HostInfo()


    def _commit_impl(self, _):
        """Raise StoreError if commits are disallowed; otherwise do nothing."""
        if self.commit_raises:
            raise host_info.StoreError('wont wont wont')


class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Tests error behaviours of CachingHostInfoStore."""

    def setUp(self):
        self.store = ExceptionRaisingStore()


    def test_failed_refresh_cleans_cache(self):
        """A failed refresh leaves nothing cached for later gets."""
        with self.assertRaises(host_info.StoreError):
            self.store.get()
        # Since |get| hit an error, a subsequent get should again hit the store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


    def test_failed_commit_cleans_cache(self):
        """Check that a failed commit cleans the cache."""
        # Let's initialize the store without errors.
        self.store.refresh_raises = False
        self.store.get(force_refresh=True)
        self.store.refresh_raises = True

        with self.assertRaises(host_info.StoreError):
            self.store.commit(host_info.HostInfo())
        # Since |commit| hit an error, a subsequent get should again hit the
        # store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """We extract the store when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5,
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)


    def test_machine_is_string(self):
        """We return a trivial store when machine is a string."""
        machine = 'hostname'
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)


class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Tests the json_serialize and json_deserialize functions."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def _serialize(self, info):
        """Serializes |info| into a new in-memory file, rewound for reading.

        @param info: The HostInfo to serialize.
        @returns: A cStringIO file object positioned at its start.
        """
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        return serialized_fp


    def test_serialize_empty(self):
        """Serializing empty HostInfo results in the expected json."""
        file_obj = self._serialize(host_info.HostInfo())
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_serialize_non_empty(self):
        """Serializing a populated HostInfo results in expected json."""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        file_obj = self._serialize(info)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_round_trip_empty(self):
        """Serializing - deserializing empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo()
        got = host_info.json_deserialize(self._serialize(info))
        self.assertEqual(got, info)


    def test_round_trip_non_empty(self):
        """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib': 'val'})
        got = host_info.json_deserialize(self._serialize(info))
        self.assertEqual(got, info)


    def test_deserialize_malformed_json_raises(self):
        """Deserializing a malformed string raises."""
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO('{labels:['))


    def test_deserialize_no_version_raises(self):
        """Deserializing a string with no serializer version raises."""
        serialized_dict = json.load(self._serialize(host_info.HostInfo()))
        del serialized_dict['serializer_version']
        serialized_no_version_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_version_str))


    def test_deserialize_malformed_host_info_raises(self):
        """Deserializing a malformed host_info raises."""
        serialized_dict = json.load(self._serialize(host_info.HostInfo()))
        # Drop a mandatory field while keeping the version intact.
        del serialized_dict['labels']
        serialized_no_labels_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_labels_str))


    def test_enforce_compatibility_version_1(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_dict means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_dict means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 1,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = cStringIO.StringIO(serialized_str)
        host_info.json_deserialize(serialized_fp)


    def test_serialize_pretty_print(self):
        """Serializing a host_info dumps the json in human-friendly format."""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        serialized_fp = self._serialize(info)
        # inspect.cleandoc strips the common indentation of every line after
        # the first, leaving a 4-space indented JSON document.
        expected = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d
        }""" % self.CURRENT_SERIALIZATION_VERSION
        self.assertEqual(serialized_fp.getvalue(), inspect.cleandoc(expected))


if __name__ == '__main__':
    unittest.main()