1from autotest_lib.frontend.afe import rpc_utils
2from autotest_lib.client.common_lib import kernel_versions
3from autotest_lib.frontend.tko import models
4
class TooManyRowsError(Exception):
    """Signals that a database query produced more rows than are allowed."""
9
10
class KernelString(str):
    """
    Custom string class that uses correct kernel version comparisons.

    Hashing, equality and ordering are all computed on
    kernel_versions.version_encode() of the string rather than on its raw
    characters, so sorting KernelStrings orders them by kernel version.

    When the other operand is a plain (non-KernelString) string it is wrapped
    in a KernelString before comparing. Comparing against a non-string returns
    NotImplemented, so Python falls back to its default handling. Note that a
    plain string may compare equal to a KernelString while hashing
    differently, so avoid mixing the two in one set or dict.
    """
    # Types treated as strings for comparison purposes: str and unicode on
    # Python 2, just str on Python 3.
    _STRING_TYPES = (str, type(u''))

    def _map(self):
        return kernel_versions.version_encode(self)


    def _compare(self, other, op):
        """
        Apply op to the version encodings of self and other.

        @param other: right-hand operand of the comparison.
        @param op: two-argument callable performing the actual comparison.
        @return: op's result, or NotImplemented if other is not a string.
        """
        if isinstance(other, KernelString):
            other_key = other._map()
        elif isinstance(other, self._STRING_TYPES):
            other_key = KernelString(other)._map()
        else:
            return NotImplemented
        return op(self._map(), other_key)


    def __hash__(self):
        # Hash the encoding so strings that are equal as versions hash alike.
        return hash(self._map())


    def __eq__(self, other):
        return self._compare(other, lambda a, b: a == b)


    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        return self._compare(other, lambda a, b: a != b)


    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b)


    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b)


    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b)


    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b)


    # These were the original (misspelled) method names. The <= and >=
    # operators look up __le__/__ge__, never these, so they never took
    # effect. They are kept as aliases in case anything calls them directly.
    __lte__ = __le__
    __gte__ = __ge__
45
46
# Column names of the per-group status-count aggregates.
_PASS_COUNT_NAME = 'pass_count'
_COMPLETE_COUNT_NAME = 'complete_count'
_INCOMPLETE_COUNT_NAME = 'incomplete_count'

# SQL aggregate expressions computing each of the columns above. They use
# COUNT(IF(cond, 1, NULL)) rather than SUM(...) because, per the original
# author, COUNT gives the resulting column a numeric (not string) type; the
# reason was not known.
_PASS_COUNT_SQL = 'COUNT(IF(status="GOOD", 1, NULL))'
_COMPLETE_COUNT_SQL = (
    'COUNT(IF(status NOT IN ("TEST_NA", "RUNNING", "NOSTATUS"), '
    '1, NULL))')
_INCOMPLETE_COUNT_SQL = 'COUNT(IF(status="RUNNING", 1, NULL))'

# Column name -> SQL expression for every status-count aggregate.
STATUS_FIELDS = dict([
    (_PASS_COUNT_NAME, _PASS_COUNT_SQL),
    (_COMPLETE_COUNT_NAME, _COMPLETE_COUNT_SQL),
    (_INCOMPLETE_COUNT_NAME, _INCOMPLETE_COUNT_SQL),
])

# Statuses that _COMPLETE_COUNT_SQL excludes from the complete count, besides
# RUNNING.
_INVALID_STATUSES = ('TEST_NA', 'NOSTATUS')
61
62
def add_status_counts(group_dict, status):
    """
    Fill in the status-count fields of group_dict for a single test.

    The counts written here match what _PASS_COUNT_SQL, _COMPLETE_COUNT_SQL
    and _INCOMPLETE_COUNT_SQL would produce for a group containing exactly one
    test with the given status. The group count is set to 1.

    @param group_dict: dict updated in place with the pass/complete/incomplete
            counts and the group count.
    @param status: the test's status string (e.g. 'GOOD', 'RUNNING').
    """
    pass_count = complete_count = incomplete_count = 0
    if status == 'GOOD':
        pass_count = complete_count = 1
    elif status == 'RUNNING':
        incomplete_count = 1
    elif status not in _INVALID_STATUSES:
        # Keep in sync with _COMPLETE_COUNT_SQL, which does not count
        # TEST_NA or NOSTATUS as complete.
        complete_count = 1
    group_dict[_PASS_COUNT_NAME] = pass_count
    group_dict[_COMPLETE_COUNT_NAME] = complete_count
    group_dict[_INCOMPLETE_COUNT_NAME] = incomplete_count
    group_dict[models.TestView.objects._GROUP_COUNT_NAME] = 1
75
76
def _construct_machine_label_header_sql(machine_labels):
    """
    Build a SQL expression naming which of machine_labels a test's host had.

    Example result for machine_labels=['Index', 'Diskful']:
    CONCAT_WS(",",
              IF(FIND_IN_SET("Diskful", tko_test_attributes_host_labels.value),
                 "Diskful", NULL),
              IF(FIND_IN_SET("Index", tko_test_attributes_host_labels.value),
                 "Index", NULL))

    Labels are emitted in sorted order, so this would result in field values
    "Diskful,Index", "Diskful", "Index", or "" -- CONCAT_WS skips NULL
    arguments, so a host with none of the labels yields an empty string.

    Backslashes and double quotes inside a label are backslash-escaped so the
    label cannot terminate its SQL string literal early. An empty
    machine_labels yields 'CONCAT_WS(",", NULL)' (which evaluates to "")
    instead of the invalid 'CONCAT_WS(",", )'.

    @param machine_labels: iterable of label names.
    @return: the SQL expression, as a string.
    """
    if_clauses = []
    for label in sorted(machine_labels):
        # Escape for a double-quoted MySQL string literal (default SQL mode).
        escaped = label.replace('\\', '\\\\').replace('"', '\\"')
        if_clauses.append(
            'IF(FIND_IN_SET("%s", tko_test_attributes_host_labels.value), '
               '"%s", NULL)' % (escaped, escaped))
    if not if_clauses:
        # CONCAT_WS needs at least one argument after the separator.
        if_clauses = ['NULL']
    return 'CONCAT_WS(",", %s)' % ', '.join(if_clauses)
95
96
class GroupDataProcessor(object):
    """
    Run a grouped TestView query and post-process its rows for the frontend.

    Usage: construct, call process_group_dicts(), then read the results with
    get_info_dict(). Processing does the following to each group dict:
      - adds an 'id' string built from its group-by values;
      - when header groups are given, adds 'header_indices' pointing into the
        sorted per-header-group value lists, and removes the group-by fields
        and the intermediate 'header_values' entry.
    """
    # Maximum number of rows the group query may return before
    # process_group_dicts() raises TooManyRowsError.
    _MAX_GROUP_RESULTS = 80000

    def __init__(self, query, group_by, header_groups, fixed_headers):
        """
        @param query: query object to run; must support .filter(**kwargs)
                (presumably a Django QuerySet over TestView -- confirm).
        @param group_by: field names to group by; duplicates are dropped,
                keeping the first occurrence.
        @param header_groups: sequence of field sequences; each one defines a
                header axis whose values are tuples of those fields.
        @param fixed_headers: dict mapping a header field to values that are
                always included among that field's header values; the query
                is also restricted to them. Each such field must be its own
                single-field header group.
        """
        self._query = query
        self._group_by = self.uniqify(group_by)
        self._header_groups = header_groups
        self._fixed_headers = dict((field, set(values))
                                   for field, values
                                   in fixed_headers.items())

        self._num_group_fields = len(group_by)
        self._header_value_sets = [set() for _ in range(len(header_groups))]
        self._group_dicts = []
        # Filled in by process_group_dicts(); initialized here so that
        # get_info_dict() does not fail if called before processing.
        self._header_values = []
        self._header_index_maps = []


    @staticmethod
    def uniqify(values):
        """
        Return the distinct items of values as a list, in first-seen order.

        Preserving order keeps the group-by order, and therefore each group's
        'id', deterministic. Items must be hashable.
        """
        seen = set()
        unique_values = []
        for value in values:
            if value not in seen:
                seen.add(value)
                unique_values.append(value)
        return unique_values


    def _restrict_header_values(self):
        """Filter the query so each fixed header field is in its values."""
        for header_field, values in self._fixed_headers.items():
            self._query = self._query.filter(**{header_field + '__in' : values})


    def _fetch_data(self):
        """Run the (restricted) group query and store the resulting rows."""
        self._restrict_header_values()
        self._group_dicts = models.TestView.objects.execute_group_query(
            self._query, self._group_by)


    @staticmethod
    def _get_field(group_dict, field):
        """
        Use special objects for certain fields to achieve custom sorting.
        -Wrap kernel versions with a KernelString
        -Replace null dates with special values
        """
        value = group_dict[field]
        if field == 'kernel':
            return KernelString(value)
        if value is None: # handle null dates as later than everything else
            if field.startswith('DATE('):
                return rpc_utils.NULL_DATE
            if field.endswith('_time'):
                return rpc_utils.NULL_DATETIME
        return value


    def _process_group_dict(self, group_dict):
        """
        Record group_dict's header values and give it a unique 'id'.

        Appends one header tuple per header group to
        group_dict['header_values'] and adds it to that group's value set.
        """
        # compute and aggregate header groups
        for i, group in enumerate(self._header_groups):
            header = tuple(self._get_field(group_dict, field)
                           for field in group)
            self._header_value_sets[i].add(header)
            group_dict.setdefault('header_values', []).append(header)

        # frontend's SelectionManager needs a unique ID
        group_values = [group_dict[field] for field in self._group_by]
        group_dict['id'] = str(group_values)
        return group_dict


    def _find_header_value_set(self, field):
        """
        Return the header value set of the single-field header group [field].

        @raises RuntimeError: if no header group consists of exactly field.
        """
        for i, group in enumerate(self._header_groups):
            # list() so that tuple-typed header groups also match.
            if list(group) == [field]:
                return self._header_value_sets[i]
        raise RuntimeError('Field %s not found in header groups %s' %
                           (field, self._header_groups))


    def _add_fixed_headers(self):
        """Ensure every fixed header value appears in its header value set."""
        for field, extra_values in self._fixed_headers.items():
            header_value_set = self._find_header_value_set(field)
            for value in extra_values:
                # Normalize exactly like values read from query rows (e.g.
                # KernelString for 'kernel'). Without this, fixed values would
                # not hash or sort consistently alongside row values.
                header_value_set.add((self._get_field({field: value}, field),))


    def _get_sorted_header_values(self):
        """
        Return each header group's values as a sorted list.

        Also builds self._header_index_maps (header value -> sorted position)
        for use by _replace_headers_with_indices().
        """
        self._add_fixed_headers()
        sorted_header_values = [sorted(value_set)
                                for value_set in self._header_value_sets]
        # construct dicts mapping headers to their indices, for use in
        # replace_headers_with_indices()
        self._header_index_maps = []
        for value_list in sorted_header_values:
            index_map = dict((value, i) for i, value in enumerate(value_list))
            self._header_index_maps.append(index_map)

        return sorted_header_values


    def _replace_headers_with_indices(self, group_dict):
        """
        Swap group_dict's header values for indices into the sorted lists.

        Also deletes the group-by fields and 'header_values' from group_dict.
        """
        group_dict['header_indices'] = [index_map[header_value]
                                        for index_map, header_value
                                        in zip(self._header_index_maps,
                                               group_dict['header_values'])]
        for field in self._group_by + ['header_values']:
            del group_dict[field]


    def process_group_dicts(self):
        """
        Fetch the grouped rows and post-process them in place.

        @raises TooManyRowsError: if the query yields more than
                _MAX_GROUP_RESULTS rows.
        """
        self._fetch_data()
        if len(self._group_dicts) > self._MAX_GROUP_RESULTS:
            raise TooManyRowsError(
                'Query yielded %d rows, exceeding maximum %d' % (
                len(self._group_dicts), self._MAX_GROUP_RESULTS))

        for group_dict in self._group_dicts:
            self._process_group_dict(group_dict)
        self._header_values = self._get_sorted_header_values()
        if self._header_groups:
            for group_dict in self._group_dicts:
                self._replace_headers_with_indices(group_dict)


    def get_info_dict(self):
        """Return the processed groups and the sorted header values."""
        return {'groups' : self._group_dicts,
                'header_values' : self._header_values}
218