1#!/usr/bin/python2
2# Copyright 2016 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import unittest
7
8import common
9
10from autotest_lib.server import utils
11from autotest_lib.server.hosts import host_info
12from autotest_lib.server.hosts import base_label
13
14 # pylint: disable=missing-docstring
15
16
class TestBaseLabel(base_label.BaseLabel):
    """Single-name BaseLabel stub whose existence is driven by the host.

    exists() simply echoes the host's `exists` attribute, so tests can choose
    whether the label applies by constructing MockHost(exists=True/False).
    """

    _NAME = 'base_label'

    def exists(self, host):
        """Return host.exists unchanged."""
        present = host.exists
        return present
24
25
class TestBaseLabels(base_label.BaseLabel):
    """BaseLabel stub whose _NAME is a list of label names.

    Used to verify that BaseLabel logic copes with _NAME holding several
    labels instead of a single string. It does not override exists(); in
    this module it is only used via get_all_labels().
    """

    _NAME = ['base_label_1', 'base_label_2']
36
37
class TestStringPrefixLabel(base_label.StringPrefixLabel):
    """
    TestStringPrefixLabel is used for testing/validating StringPrefixLabel
    methods.

    This test class checks that prefix labels are constructed properly from
    the label passed in during construction: generate_labels() always
    returns that single value, and the tests expect get() to report it as
    '<_NAME>:<label>'.
    """

    _NAME = 'prefix'

    def __init__(self, label='postfix'):
        """
        @param label: The value generate_labels() will report.
        """
        self.label_to_return = label


    def generate_labels(self, _):
        """
        Return the label configured at construction time.

        @param _: The host; ignored by this stub.

        @returns A one-element list holding the configured label.
        """
        return [self.label_to_return]
54
55
class MockAFEHost(utils.EmptyAFEHost):
    """AFE host stand-in exposing only `labels` and `attributes`."""

    def __init__(self, labels=None, attributes=None):
        """
        @param labels: Labels to expose; any falsy value becomes [].
        @param attributes: Attributes to expose; any falsy value becomes {}.
        """
        # NOTE(review): the parent __init__ is not invoked; both attributes
        # are assigned directly here.
        self.labels = labels if labels else []
        self.attributes = attributes if attributes else {}
61
62
class MockHost(object):
    """Minimal host double exposing hostname, exists and host_info_store."""

    def __init__(self, exists=True, store=None):
        """
        @param exists: Value TestBaseLabel.exists() reports back for this host.
        @param store: Host info store (e.g. an InMemoryHostInfoStore) or None.
        """
        self.host_info_store = store
        self.exists = exists
        self.hostname = 'hostname'
69
70
class BaseLabelUnittests(unittest.TestCase):
    """Unittest for testing base_label.BaseLabel."""

    def setUp(self):
        self.test_base_label = TestBaseLabel()
        self.test_base_labels = TestBaseLabels()


    def test_generate_labels(self):
        """generate_labels() should wrap the single _NAME in a list."""
        label = self.test_base_label
        actual = label.generate_labels(None)
        self.assertEqual(actual, [label._NAME])


    def test_get(self):
        """get() returns the label only when exists() is true for the host."""
        label = self.test_base_label
        # A host that exists yields the label...
        self.assertEqual(label.get(MockHost(exists=True)), [label._NAME])
        # ...while one that does not yields nothing.
        self.assertEqual(label.get(MockHost(exists=False)), [])


    def test_get_all_labels(self):
        """get_all_labels() reports full labels only, for str or list _NAME."""
        single_prefixes, single_full = self.test_base_label.get_all_labels()
        multi_prefixes, multi_full = self.test_base_labels.get_all_labels()

        # Full labels must come back as a set whether _NAME is a single
        # string or a list of strings.
        self.assertEqual(single_full, {self.test_base_label._NAME})
        self.assertEqual(multi_full, set(self.test_base_labels._NAME))

        # A plain BaseLabel never contributes prefix labels.
        self.assertEqual(single_prefixes, set())
        self.assertEqual(multi_prefixes, set())


    def test_update_for_task(self):
        """update_for_task('') should return a truthy value."""
        self.assertTrue(self.test_base_label.update_for_task(''))
112
113
class StringPrefixLabelUnittests(unittest.TestCase):
    """Unittest for testing base_label.StringPrefixLabel."""

    def setUp(self):
        self.postfix_label = 'postfix_label'
        self.test_label = TestStringPrefixLabel(label=self.postfix_label)


    def test_get(self):
        """get() should return '<_NAME>:<label>' for the generated label."""
        expected = '%s:%s' % (self.test_label._NAME, self.postfix_label)
        self.assertEqual(self.test_label.get(None), [expected])


    def test_get_all_labels(self):
        """Only a '<_NAME>:' prefix is reported, and no full labels."""
        prefix_labels, full_labels = self.test_label.get_all_labels()
        expected_prefix = '%s:' % self.test_label._NAME
        self.assertEqual(prefix_labels, {expected_prefix})
        self.assertEqual(full_labels, set())
134
135
class LabelRetrieverUnittests(unittest.TestCase):
    """Unittest for testing base_label.LabelRetriever."""

    def setUp(self):
        label_list = [TestStringPrefixLabel(), TestBaseLabel()]
        self.retriever = base_label.LabelRetriever(label_list)
        # Explicitly populate the retriever's known-label sets (task name '')
        # so label_full_names / label_prefix_names can be inspected directly.
        self.retriever._populate_known_labels(label_list, '')


    def test_populate_known_labels(self):
        """Check that _populate_known_labels() works as expected."""
        full_names = set([TestBaseLabel._NAME])
        prefix_names = set(['%s:' % TestStringPrefixLabel._NAME])
        # Check on a normal retriever.
        self.assertEqual(self.retriever.label_full_names, full_names)
        self.assertEqual(self.retriever.label_prefix_names, prefix_names)


    def test_is_known_label(self):
        """Check _is_known_label() detects/skips the right labels."""
        # (label, expected result of _is_known_label) pairs.
        labels_to_check = [
                # Exact full-name and prefix matches are known.
                (TestBaseLabel._NAME, True),
                ('%s:' % TestStringPrefixLabel._NAME, True),
                # Truncated (partial) names must not match.
                (TestBaseLabel._NAME[:2], False),
                ('%s:' % TestStringPrefixLabel._NAME[:2], False),
                ('no_label_match', False),
        ]

        for label, expected_known in labels_to_check:
            # Name the offending label in the failure message; otherwise a
            # failure inside this loop does not say which case broke.
            self.assertEqual(self.retriever._is_known_label(label),
                             expected_known,
                             '_is_known_label(%r) should be %s' %
                             (label, expected_known))


    def test_update_labels(self):
        """Check that we add/remove the expected labels in update_labels()."""
        label_to_add = 'label_to_add'
        # Carries the 'prefix:' prefix owned by TestStringPrefixLabel.
        label_to_remove = 'prefix:label_to_remove'
        store = host_info.InMemoryHostInfoStore(
                info=host_info.HostInfo(
                        labels=[label_to_remove, TestBaseLabel._NAME],
                ),
        )
        mockhost = MockHost(store=store)

        retriever = base_label.LabelRetriever(
                [TestStringPrefixLabel(label=label_to_add),
                 TestBaseLabel()])
        retriever.update_labels(mockhost)
        # The stale prefixed label must be replaced by the newly generated
        # one, and the full-name label must survive.
        self.assertEqual(
                set(store.get().labels),
                {'%s:%s' % (TestStringPrefixLabel._NAME, label_to_add),
                 TestBaseLabel._NAME},
        )
190
191
if __name__ == '__main__':
    # When executed directly, let unittest discover and run the TestCases
    # defined in this module.
    unittest.main()
194
195