1# Copyright 2016 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import cStringIO
6import inspect
7import json
8import unittest
9
10import common
11from autotest_lib.server.hosts import host_info
12
13
class HostInfoTest(unittest.TestCase):
    """Covers HostInfo equality, derived properties and label mutators."""

    def setUp(self):
        # Every test starts from a HostInfo built with no arguments.
        self.info = host_info.HostInfo()


    def test_info_comparison_to_wrong_type(self):
        """A HostInfo must not compare equal to a non-HostInfo object."""
        foreign_values = (42, None)
        for foreign in foreign_values:
            self.assertNotEqual(host_info.HostInfo(), foreign)
        # == and != are separate hooks in the data model, so exercise == too.
        for foreign in foreign_values:
            self.assertFalse(host_info.HostInfo() == foreign)


    def test_empty_infos_are_equal(self):
        """Two HostInfo objects built without arguments compare equal."""
        first = host_info.HostInfo()
        second = host_info.HostInfo()
        self.assertEqual(first, second)
        # == and != are separate hooks in the data model, so exercise != too.
        self.assertFalse(first != second)


    def test_non_trivial_infos_are_equal(self):
        """Infos with repeated labels and None attributes compare equal."""
        labels = ['label1', 'label2', 'label1']
        attributes = {'attrib1': None, 'attrib2': 'val2'}
        # Hand each info its own copy so the two objects share no containers.
        left = host_info.HostInfo(labels=list(labels),
                                  attributes=dict(attributes))
        right = host_info.HostInfo(labels=list(labels),
                                   attributes=dict(attributes))
        self.assertEqual(left, right)
        # == and != are separate hooks in the data model, so exercise != too.
        self.assertFalse(left != right)


    def test_non_equal_infos(self):
        """Infos holding different data compare unequal."""
        labels_only = host_info.HostInfo(labels=['label'])
        attributes_only = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(labels_only, attributes_only)
        # == and != are separate hooks in the data model, so exercise == too.
        self.assertFalse(labels_only == attributes_only)


    def test_build_needs_prefix(self):
        """Version labels lacking the trailing ':' do not yield a build."""
        self.info.labels = [
                'cros-version',
                'fwrw-version',
                'fwro-version',
        ]
        self.assertIsNone(self.info.build)


    def test_build_prefix_must_be_anchored(self):
        """A version prefix occurring mid-label does not yield a build."""
        self.info.labels = ['not-at-start-cros-version:cros1']
        self.assertIsNone(self.info.build)


    def test_build_ignores_firmware(self):
        """Firmware version labels alone do not yield a build."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        self.assertIsNone(self.info.build)


    def test_build_returns_first_match(self):
        """With several cros-version labels, the first one is the build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        self.assertEqual(self.info.build, 'cros1')


    def test_build_prefer_cros_over_others(self):
        """cros-version wins over cheets-version regardless of label order."""
        label_orderings = (
                ['cheets-version:ab1', 'cros-version:cros1'],
                ['cros-version:cros1', 'cheets-version:ab1'],
        )
        for ordering in label_orderings:
            self.info.labels = ordering
            self.assertEqual(self.info.build, 'cros1')


    def test_os_no_match(self):
        """os is the empty string when no label starts with 'os:'."""
        self.info.labels = [
                'something_else',
                'cros-version:hana',
                'os_without_colon',
        ]
        self.assertEqual(self.info.os, '')


    def test_os_returns_first_match(self):
        """os comes from the first 'os:' label."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        self.assertEqual(self.info.os, 'linux')


    def test_board_no_match(self):
        """board is the empty string when no label starts with 'board:'."""
        self.info.labels = [
                'something_else',
                'cros-version:hana',
                'os:blah',
                'board_my_board_no_colon',
        ]
        self.assertEqual(self.info.board, '')


    def test_board_returns_first_match(self):
        """board comes from the first well-formed 'board:' label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        self.assertEqual(self.info.board, 'walk')


    def test_pools_no_match(self):
        """pools is an empty set when no label starts with 'pool:'."""
        self.info.labels = [
                'something_else',
                'cros-version:hana',
                'os:blah',
                'board_my_board_no_colon',
                'board:my_board',
        ]
        self.assertEqual(self.info.pools, set())


    def test_pools_returns_all_matches(self):
        """pools collects the value of every 'pool:' label."""
        self.info.labels = [
                'board_corrupted',
                'board:walk',
                'board:bored',
                'pool:first_pool',
                'pool:second_pool',
        ]
        self.assertEqual(self.info.pools, set(['first_pool', 'second_pool']))


    def test_str(self):
        """str() of a HostInfo shows its labels and attributes."""
        info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
        rendered = str(info)
        self.assertEqual(rendered,
                         "HostInfo[Labels: ['a'], Attributes: {'b': 2}]")


    def test_clear_version_labels_no_labels(self):
        """clear_version_labels leaves labels alone if none is a version."""
        untouched = [
                'board:something',
                'os:something_else',
                'pool:mypool',
                'cheets-version-corrupted:blah',
                'cros-version',
        ]
        # Assign a copy so the expectation cannot be mutated by the call.
        self.info.labels = untouched[:]
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, untouched)


    def test_clear_all_version_labels(self):
        """clear_version_labels drops cros-version and cheets-version."""
        self.info.labels = [
                'extra_label',
                'cros-version:cr1',
                'cheets-version:ab1',
        ]
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_clear_all_version_label_prefixes(self):
        """clear_version_labels drops version labels with an empty value."""
        self.info.labels = ['extra_label', 'cros-version:', 'cheets-version:']
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_set_version_labels_updates_in_place(self):
        """set_version_label rewrites an existing label at its position."""
        self.info.labels = ['extra', 'cros-version:X', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        expected = ['extra', 'cros-version:Z', 'cheets-version:Y']
        self.assertListEqual(self.info.labels, expected)


    def test_set_version_labels_appends(self):
        """set_version_label appends when the prefix is not yet present."""
        self.info.labels = ['extra', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        expected = ['extra', 'cheets-version:Y', 'cros-version:Z']
        self.assertListEqual(self.info.labels, expected)
176
177
class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Exercises get / commit caching through an InMemoryHostInfoStore."""

    def setUp(self):
        # A fresh store per test keeps cached state from leaking across tests.
        self.store = host_info.InMemoryHostInfoStore()


    def _verify_host_info_data(self, host_info, labels, attributes):
        """Asserts that |host_info| carries exactly the given data.

        @param host_info: Object exposing |labels| and |attributes|.
        @param labels: Expected list of labels.
        @param attributes: Expected dict of attributes.
        """
        self.assertListEqual(host_info.labels, labels)
        self.assertDictEqual(host_info.attributes, attributes)


    def test_first_get_refreshes_cache(self):
        """The first get() returns what is currently in the store."""
        self.store.info = host_info.HostInfo(labels=['label1'],
                                             attributes={'attrib1': 'val1'})
        fetched = self.store.get()
        self._verify_host_info_data(fetched, ['label1'], {'attrib1': 'val1'})


    def test_repeated_get_returns_from_cache(self):
        """A second plain get() does not pick up changes made to the store."""
        self.store.info = host_info.HostInfo(labels=['label1'],
                                             attributes={'attrib1': 'val1'})
        first_fetch = self.store.get()
        self._verify_host_info_data(first_fetch,
                                    ['label1'], {'attrib1': 'val1'})

        # Swap the backing data; the cached copy is expected to be served.
        self.store.info = host_info.HostInfo(labels=['label1', 'label2'],
                                             attributes={})
        second_fetch = self.store.get()
        self._verify_host_info_data(second_fetch,
                                    ['label1'], {'attrib1': 'val1'})


    def test_get_uncached_always_refreshes_cache(self):
        """get(force_refresh=True) always returns the store's current data."""
        self.store.info = host_info.HostInfo(labels=['label1'],
                                             attributes={'attrib1': 'val1'})
        first_fetch = self.store.get(force_refresh=True)
        self._verify_host_info_data(first_fetch,
                                    ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(labels=['label1', 'label2'],
                                             attributes={})
        second_fetch = self.store.get(force_refresh=True)
        self._verify_host_info_data(second_fetch, ['label1', 'label2'], {})


    def test_commit(self):
        """commit() writes the given info into the store."""
        to_commit = host_info.HostInfo(labels=['label1'],
                                       attributes={'attrib1': 'val1'})
        # The store starts out empty.
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(to_commit)
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get(self):
        """A plain get() after commit() returns the committed data."""
        before = self.store.get()
        self._verify_host_info_data(before, [], {})

        self.store.commit(host_info.HostInfo(labels=['label1'],
                                             attributes={'attrib1': 'val1'}))
        after = self.store.get()
        self._verify_host_info_data(after, ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get_uncached(self):
        """A forced-refresh get() after commit() returns the committed data."""
        before = self.store.get()
        self._verify_host_info_data(before, [], {})

        self.store.commit(host_info.HostInfo(labels=['label1'],
                                             attributes={'attrib1': 'val1'}))
        after = self.store.get(force_refresh=True)
        self._verify_host_info_data(after, ['label1'], {'attrib1': 'val1'})


    def test_commit_deepcopies_data(self):
        """Mutating a HostInfo after commit() does not alter the store."""
        committed = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib1': {'key1': 'data1'}})
        self.store.commit(committed)
        # Touch both the top-level list and a nested dict of the original.
        committed.labels.append('label2')
        committed.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_get_returns_deepcopy(self):
        """Mutating the object returned by get() does not alter the cache."""
        self.store.info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib1': {'key1': 'data1'}})
        first_fetch = self.store.get()
        self._verify_host_info_data(first_fetch,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        # Touch both the top-level list and a nested dict of the result.
        first_fetch.labels.append('label2')
        first_fetch.attributes['attrib1']['key1'] = 'data2'
        second_fetch = self.store.get()
        self._verify_host_info_data(second_fetch,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_str(self):
        """str() of the store wraps the str() of the info it holds."""
        self.store.info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib1': {'key1': 'data1'}})
        expected = 'InMemoryHostInfoStore[%s]' % self.store.info
        self.assertEqual(str(self.store), expected)
281
282
class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A test store whose refresh / commit raise StoreError while enabled.

    |refresh_raises| and |commit_raises| both start out True; a test may set
    either to False to let the corresponding operation succeed.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.commit_raises = True
        self.refresh_raises = True


    def _refresh_impl(self):
        """Returns an empty HostInfo unless |refresh_raises| is set.

        @raises host_info.StoreError: When |refresh_raises| is truthy.
        """
        if not self.refresh_raises:
            return host_info.HostInfo()
        raise host_info.StoreError('no can do')


    def _commit_impl(self, _):
        """Does nothing unless |commit_raises| is set.

        @raises host_info.StoreError: When |commit_raises| is truthy.
        """
        if not self.commit_raises:
            return
        raise host_info.StoreError('wont wont wont')
300
301
class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Checks how CachingHostInfoStore behaves when the backend raises."""

    def setUp(self):
        self.store = ExceptionRaisingStore()


    def test_failed_refresh_cleans_cache(self):
        """A get() that fails to refresh leaves nothing cached behind."""
        self.assertRaises(host_info.StoreError, self.store.get)
        # The first |get| failed, so the next one must go to the store again
        # and therefore fail the same way.
        self.assertRaises(host_info.StoreError, self.store.get)


    def test_failed_commit_cleans_cache(self):
        """A commit() that fails invalidates the previously cached info."""
        # Prime the cache with one successful refresh, then re-arm the error.
        self.store.refresh_raises = False
        self.store.get(force_refresh=True)
        self.store.refresh_raises = True

        self.assertRaises(host_info.StoreError,
                          self.store.commit, host_info.HostInfo())
        # The |commit| failed, so the next |get| must go to the store again
        # instead of serving the primed cache.
        self.assertRaises(host_info.StoreError, self.store.get)
331
332
class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """The 'host_info_store' entry is returned when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)


    def test_machine_is_string(self):
        """An InMemoryHostInfoStore is returned when machine is a string."""
        machine = 'hostname'
        # assertIsInstance reports the offending object and expected type on
        # failure, unlike assertTrue(isinstance(...)) which only says
        # "False is not true".
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)
350
351
class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Exercises json_serialize / json_deserialize for HostInfo."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def _serialize_to_buffer(self, info):
        """Serializes |info| into a new in-memory file rewound to offset 0.

        @param info: The HostInfo to serialize.
        @returns: A cStringIO object positioned at its start.
        """
        buf = cStringIO.StringIO()
        host_info.json_serialize(info, buf)
        buf.seek(0)
        return buf


    def _serialized_empty_info_without(self, key):
        """Returns the json text of an empty HostInfo with |key| removed.

        @param key: Top-level key to delete from the serialized dict.
        """
        payload = json.load(self._serialize_to_buffer(host_info.HostInfo()))
        del payload[key]
        return json.dumps(payload)


    def test_serialize_empty(self):
        """An empty HostInfo serializes to the expected json."""
        serialized = self._serialize_to_buffer(host_info.HostInfo())
        expected = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
        }
        self.assertEqual(json.load(serialized), expected)


    def test_serialize_non_empty(self):
        """A populated HostInfo serializes to the expected json."""
        populated = host_info.HostInfo(labels=['label1'],
                                       attributes={'attrib': 'val'})
        serialized = self._serialize_to_buffer(populated)
        expected = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
        }
        self.assertEqual(json.load(serialized), expected)


    def test_round_trip_empty(self):
        """An empty HostInfo survives serialize followed by deserialize."""
        original = host_info.HostInfo()
        restored = host_info.json_deserialize(
                self._serialize_to_buffer(original))
        self.assertEqual(restored, original)


    def test_round_trip_non_empty(self):
        """A populated HostInfo survives serialize followed by deserialize."""
        original = host_info.HostInfo(labels=['label1'],
                                      attributes={'attrib': 'val'})
        restored = host_info.json_deserialize(
                self._serialize_to_buffer(original))
        self.assertEqual(restored, original)


    def test_deserialize_malformed_json_raises(self):
        """Input that is not valid json raises DeserializationError."""
        broken = cStringIO.StringIO('{labels:[')
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(broken)


    def test_deserialize_no_version_raises(self):
        """Json lacking 'serializer_version' raises DeserializationError."""
        versionless = self._serialized_empty_info_without(
                'serializer_version')
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO(versionless))


    def test_deserialize_malformed_host_info_raises(self):
        """Json lacking 'labels' raises DeserializationError."""
        labelless = self._serialized_empty_info_without('labels')
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO(labelless))


    def test_enforce_compatibility_version_1(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_info_str means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_info_str means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 1,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = cStringIO.StringIO(serialized_str)
        host_info.json_deserialize(serialized_fp)


    def test_serialize_pretty_print(self):
        """json_serialize writes indented, key-sorted, human-friendly json."""
        populated = host_info.HostInfo(labels=['label1'],
                                       attributes={'attrib': 'val'})
        # The buffer is read back with getvalue(), so it is not rewound here.
        out = cStringIO.StringIO()
        host_info.json_serialize(populated, out)
        template = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d
        }"""
        # cleandoc strips the source-level indentation from the literal.
        expected = inspect.cleandoc(
                template % self.CURRENT_SERIALIZATION_VERSION)
        self.assertEqual(out.getvalue(), expected)
483
484
# Run every TestCase defined in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
487