1# Copyright 2016 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""This class defines the Base Label classes."""
6
7
8import logging
9
10import common
11from autotest_lib.server.hosts import afe_store
12from autotest_lib.server.hosts import host_info
13from autotest_lib.server.hosts import shadowing_store
14
15
def forever_exists_decorate(exists):
    """
    Decorator for labels that should exist forever once applied.

    The returned wrapper first checks whether the label's _NAME is already
    among the labels stored in the host's info store and returns True if so.
    Only otherwise is the original exists() detection logic run.

    @param exists: The exists method on the label class.

    @returns the wrapping exists method. The original method's name and
        docstring are preserved via functools.wraps.
    """
    # Imported locally so this decorator needs no change to the module's
    # top-level import block.
    import functools

    @functools.wraps(exists)
    def exists_wrapper(self, host):
        """
        Wrapper around the label exists method.

        @param self: The label object.
        @param host: The host object to run methods on.

        @returns True if the label already exists on the host, otherwise the
            result of the wrapped exists method.
        """
        info = host.host_info_store.get()
        # Short-circuit: once applied, the label is kept without re-detection.
        return (self._NAME in info.labels) or exists(self, host)
    return exists_wrapper
38
39
class BaseLabel(object):
    """
    This class contains the scaffolding for the host-specific labels.

    @property _NAME String that is either the label returned or a prefix of a
                    generated label. Classes that can return one of several
                    full labels may instead set it to a list, tuple or set
                    holding every possible label name (see get_all_labels).
    """

    _NAME = None

    def generate_labels(self, host):
        """
        Return the list of labels generated for the host.

        @param host: The host object to check on. Not needed here for base case
                     but could be needed for subclasses.

        @return a list of labels applicable to the host.
        """
        return [self._NAME]


    def exists(self, host):
        """
        Checks the host if the label is applicable or not.

        This method is geared for the type of labels that indicate if the host
        has a feature (bluetooth, touchscreen, etc) and as such require
        detection logic to determine if the label should be applicable to the
        host or not.

        @param host: The host object to check on.

        @raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError('exists not implemented')


    def get(self, host):
        """
        Return the list of labels.

        @param host: The host object to check on.

        @returns generate_labels(host) when exists(host) is true, otherwise
            an empty list.
        """
        if self.exists(host):
            return self.generate_labels(host)
        return []


    def get_all_labels(self):
        """
        Return all possible labels generated by this label class.

        If _NAME is a list, tuple, set or frozenset, each element is treated
        as one full label name. Otherwise _NAME itself is the single full label.

        @returns a tuple of sets, the first set is for labels that are prefixes
            like 'os:android'. The second set is for labels that are full
            labels by themselves like 'bluetooth'.
        """
        # Another subclass takes care of prefixed labels so this is empty.
        prefix_labels = set()
        if isinstance(self._NAME, (list, tuple, set, frozenset)):
            # Several possible labels: expand the container so each name can
            # be matched individually (a bare tuple in the set never would).
            full_labels = set(self._NAME)
        else:
            full_labels = set([self._NAME])

        return prefix_labels, full_labels


    def update_for_task(self, task_name):
        """
        This method helps to check which labels need to be updated.
        State config labels are updated only for repair task.
        Lab config labels are updated only for deploy task.
        All labels are updated for any task.

        It is the responsibility of the subclass to override this method
        to differentiate itself as a state config label or a lab config label
        and return the appropriate boolean value.

        If the subclass doesn't override this method then that label will
        always be updated for any type of task.

        @param task_name: Name of the task being run (unused here).

        @returns True if labels should be updated for the task with given name
        """
        return True
122
123
class StringLabel(BaseLabel):
    """
    A dynamically generated string label that always applies to the host.

    Use this class for labels that are always present and resolve to at least
    one value out of a fixed set of candidates. Subclasses decide which
    candidate(s) to report by implementing generate_labels().

    Subclasses must override _NAME with every label they could possibly
    return. get_all_labels() reports those names as full labels, which allows
    accurate label updating.
    """

    def exists(self, host):
        """
        Report the label as applicable; string labels are assumed to always be.

        @param host: The host object (unused).

        @returns True.
        """
        return True


    def generate_labels(self, host):
        """
        Return the labels detected for the host.

        @param host: The host object to check on.

        @raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError('generate_labels not implemented')
146
147
class StringPrefixLabel(StringLabel):
    """
    A dynamically generated string label reported as '<_NAME>:<value>'.

    Use this class for labels that are usually always present and describe
    the os/board/etc. type of the host. Each value from generate_labels() is
    prefixed with _NAME and a colon, for example:

        _NAME = 'os'
        generate_labels() returns ['android']

    yields the labels ['os:android']. Subclasses must override _NAME;
    otherwise every label returned will be prefixed with 'None:'.
    """

    def get(self, host):
        """
        Return the generated labels, each prefixed with '_NAME:'.

        @param host: The host object to check on.

        @returns a list of prefixed labels, or an empty list when exists(host)
            is false.
        """
        if not self.exists(host):
            return []
        prefix = self._NAME
        return ['%s:%s' % (prefix, value)
                for value in self.generate_labels(host)]


    def get_all_labels(self):
        """
        Return all possible labels generated by this label class.

        @returns a tuple of sets, the first set is for labels that are prefixes
            like 'os:android'. The second set is for labels that are full
            labels by themselves like 'bluetooth'.
        """
        # Prefix labels only. The trailing ':' ensures that only labels of the
        # form '<_NAME>:...' match, not unrelated labels sharing the start.
        return {'%s:' % self._NAME}, set()
191
192
class LabelRetriever(object):
    """This class will assist in retrieving/updating the host labels."""

    def _populate_known_labels(self, label_list, task_name):
        """
        Record the labels that this detection framework may create.

        Only label instances whose update_for_task(task_name) is True
        contribute. Their prefix/full names are added to label_prefix_names
        and label_full_names. task_name is remembered so that update_labels()
        can tell which task the cached sets were built for.

        @param label_list: Iterable of label instances to inspect.
        @param task_name: task name(repair/deploy) for the operation.
        """
        for label_instance in label_list:
            # populate only the labels that need to be updated for this task.
            if label_instance.update_for_task(task_name):
                prefixed_labels, full_labels = label_instance.get_all_labels()
                self.label_prefix_names.update(prefixed_labels)
                self.label_full_names.update(full_labels)
        self._known_labels_task = task_name


    def __init__(self, label_list):
        self._labels = label_list
        # These two sets will contain the list of labels we can safely remove
        # during the update_labels call.
        self.label_full_names = set()
        self.label_prefix_names = set()
        # Task name the two sets above were computed for (None: not yet).
        self._known_labels_task = None


    def _ensure_known_labels(self, task_name):
        """
        Make the known-label sets correspond to task_name, rebuilding if needed.

        The sets depend on which labels update_for_task(task_name) selects. A
        cache built for one task (e.g. repair) must not be reused for another
        (e.g. deploy). Otherwise labels that this task does not refresh would
        be treated as known and dropped from the host.

        @param task_name: task name(repair/deploy) for the operation.
        """
        if (self._known_labels_task == task_name and
                (self.label_full_names or self.label_prefix_names)):
            return
        self.label_full_names.clear()
        self.label_prefix_names.clear()
        self._populate_known_labels(self._labels, task_name)


    def get_labels(self, host):
        """
        Retrieve the labels for the host.

        Errors raised by an individual label are logged and that label is
        skipped.

        @param host: The host to get the labels for.

        @returns a list of all labels detected for the host.
        """
        labels = []
        for label in self._labels:
            logging.info('checking label %s', label.__class__.__name__)
            try:
                labels.extend(label.get(host))
            except Exception:
                logging.exception('error getting label %s.',
                                  label.__class__.__name__)
        return labels


    def get_labels_for_update(self, host, task_name):
        """
        Retrieve the labels for the host which needs to be updated.

        Errors raised by an individual label are logged and that label is
        skipped.

        @param host: The host to get the labels for updating.
        @param task_name: task name(repair/deploy) for the operation.

        @returns labels to be updated
        """
        labels = []
        for label in self._labels:
            try:
                # get only the labels which need to be updated for this task.
                if label.update_for_task(task_name):
                    logging.info('checking label update %s',
                                 label.__class__.__name__)
                    labels.extend(label.get(host))
            except Exception:
                logging.exception('error getting label %s.',
                                  label.__class__.__name__)
        return labels


    def _is_known_label(self, label):
        """
        Checks if the label is a label known to the label detection framework.

        @param label: The label to check.

        @returns True if the label is known to the framework (so it may be
            removed when no longer detected), False if it was created
            elsewhere (so it must be kept).
        """
        # str.startswith accepts a tuple and returns False for an empty one.
        return (label in self.label_full_names or
                label.startswith(tuple(self.label_prefix_names)))


    def _carry_over_unknown_labels(self, old_labels, new_labels):
        """Update new_labels by adding back old unknown labels.

        We only delete labels that we might have created earlier. There are
        some labels we should not be removing (e.g. pool:bvt) that we
        want to keep but won't be part of the new labels detected on the host.
        To do that we compare the passed in label to our list of known labels
        and if we get a match, we feel safe knowing we can remove the label.
        Otherwise we leave that label alone since it was generated elsewhere.

        Carried-over labels are appended once each, in the order they appear
        in old_labels, so the result is deterministic.

        @param old_labels: List of labels already on the host.
        @param new_labels: List of newly detected labels. This list will be
                updated to add back labels that are not tracked by the detection
                framework.
        """
        seen = set(new_labels)
        for label in old_labels:
            if label in seen:
                continue
            seen.add(label)
            if not self._is_known_label(label):
                new_labels.append(label)


    def _commit_info(self, host, new_info, keep_pool):
        """
        Commit new_info to the host's info store.

        If keep_pool is set and the host's store is a
        shadowing_store.ShadowingStore, commit via commit_with_substitute()
        with an afe_store.AfeStoreKeepPool primary store and no shadow store.
        Presumably this preserves the host's pool labels in the AFE; verify
        against afe_store. Otherwise commit through the host's own store.

        @param host: The host whose info store is updated.
        @param new_info: The host_info.HostInfo to commit.
        @param keep_pool: Whether to use the keep-pool commit path.
        """
        if keep_pool and isinstance(host.host_info_store,
                                    shadowing_store.ShadowingStore):
            primary_store = afe_store.AfeStoreKeepPool(host.hostname)
            host.host_info_store.commit_with_substitute(
                    new_info,
                    primary_store=primary_store,
                    shadow_store=None)
            return

        host.host_info_store.commit(new_info)


    def update_labels(self, host, task_name='', keep_pool=False):
        """
        Retrieve the labels from the host and update if needed.

        Newly detected labels replace the known labels stored for the host.
        Labels not known to this framework (e.g. pool:bvt) are carried over.
        The host info is only committed when it actually changed.

        @param host: The host to update the labels for.
        @param task_name: task name(repair/deploy) for the operation; only
                labels whose update_for_task(task_name) is True are refreshed.
        @param keep_pool: If True and the host uses a ShadowingStore, commit
                through afe_store.AfeStoreKeepPool (see _commit_info).
        """
        # Known labels depend on the task, so (re)build them for this task.
        self._ensure_known_labels(task_name)

        # Label detection hits the DUT so it can be slow. Do it before reading
        # old labels from HostInfoStore to minimize the time between read and
        # commit of the HostInfo.
        new_labels = self.get_labels_for_update(host, task_name)
        old_info = host.host_info_store.get()
        self._carry_over_unknown_labels(old_info.labels, new_labels)
        new_info = host_info.HostInfo(
                labels=new_labels,
                attributes=old_info.attributes,
                stable_versions=old_info.stable_versions,
        )
        if old_info != new_info:
            self._commit_info(host, new_info, keep_pool)
323