1#!/usr/bin/env python3
2#  Copyright 2016 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS-IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import argparse
17import json
18import yaml
19from collections import defaultdict
20
def extract_results(bench_results, fixed_benchmark_params, column_dimension, row_dimension, result_dimension):
    """Collect one result dimension from a list of benchmark results into a 2-D table.

    Args:
        bench_results: iterable of benchmark result dicts, each with a 'benchmark' dict (the benchmark
            params) and a 'results' dict.
        fixed_benchmark_params: params that a result must match exactly (values compared after
            make_immutable) to be included. Matching params are removed before the row/column lookup.
        column_dimension: name of the benchmark param whose value selects the table column.
        row_dimension: name of the benchmark param whose value selects the table row.
        result_dimension: key in 'results' whose value is stored in each table cell.

    Returns:
        A defaultdict mapping row value -> {column value -> results[result_dimension]}.
        Results that don't match fixed_benchmark_params, or that lack result_dimension, are skipped.

    Raises:
        Exception: wrapping the underlying error (the message includes the offending result) when a
            matching result lacks the row/column param, or when two matching results map to the same cell.
    """
    table_data = defaultdict(dict)
    # For each (row, column) cell: the params that were neither fixed nor used as row/column. Only used
    # to produce a helpful error message when two results collide on the same cell.
    remaining_dimensions_by_row_column = dict()
    for bench_result in bench_results:
        try:
            params = {dimension_name: make_immutable(dimension_value)
                      for dimension_name, dimension_value in bench_result['benchmark'].items()}
            results = bench_result['results']
            if result_dimension not in results:
                # result_dimension not found in this result, skip. This is checked up front so that the
                # skip also applies when fixed_benchmark_params is empty.
                continue
            if any(params.get(param_name) != param_value
                   for param_name, param_value in fixed_benchmark_params.items()):
                # fixed_benchmark_params not satisfied by this result, skip
                continue
            for param_name in fixed_benchmark_params:
                # A fixed param whose expected value is None also matches a result that lacks the param
                # (params.get() returns None), so the param may legitimately be absent here.
                params.pop(param_name, None)
            if row_dimension not in params:
                raise Exception('%s not in %s' % (row_dimension, params.keys()))
            if column_dimension not in params:
                raise Exception('%s not in %s' % (column_dimension, params.keys()))
            row_value = params[row_dimension]
            column_value = params[column_dimension]
            remaining_dimensions = params.copy()
            remaining_dimensions.pop(row_dimension)
            remaining_dimensions.pop(column_dimension)
            if column_value in table_data[row_value]:
                previous_remaining_dimensions = remaining_dimensions_by_row_column[(row_value, column_value)]
                raise Exception(
                    'Found multiple benchmark results with the same fixed benchmark params, benchmark param for row and benchmark param for column, so a result can\'t be uniquely determined. '
                    + 'Consider adding additional values in fixed_benchmark_params. Remaining dimensions: %s vs %s' % (
                        remaining_dimensions, previous_remaining_dimensions))
            table_data[row_value][column_value] = results[result_dimension]
            remaining_dimensions_by_row_column[(row_value, column_value)] = remaining_dimensions
        except Exception as e:
            raise Exception('While processing %s' % bench_result) from e
    return table_data
58
59
def identity(x):
    """Return the argument unchanged (the no-op pretty printer)."""
    unchanged = x
    return unchanged
62
63
def print_markdown_table(table_data):
    """Print a 2-dimensional list (a list of rows) as a markdown table on stdout.

    The first row is treated as the header and is followed by a separator line.
    Every cell is converted with str() and right-aligned to the width of the
    widest cell in its column.
    """
    widths = [max(len(str(row[i])) for row in table_data)
              for i in range(len(table_data[0]))]
    for position, row in enumerate(table_data):
        cells = [str(cell).rjust(widths[i]) for i, cell in enumerate(row)]
        print('| %s |' % ' | '.join(cells))
        if position == 0:
            # Separator line under the header, e.g. |---|-----|---|
            dashes = ['-' * widths[i] for i in range(len(row))]
            print('|-%s-|' % '-|-'.join(dashes))
83
def compute_min_max(table_data, row_headers, column_headers):
    """Return (lowest, highest) bound over all cells of the table.

    Each cell value is indexed as a pair of [low, high] intervals (value[0] and
    value[1]); both intervals contribute to the bounds. Cells missing from a row
    are ignored. The result is passed to the value pretty-printer so that it can
    choose a single unit that works well for every value in the table.
    """
    intervals_by_row = {}
    for row_header in row_headers:
        row = table_data[row_header]
        intervals_by_row[row_header] = [row[column_header]
                                        for column_header in column_headers
                                        if column_header in row]

    row_lows = []
    for row_header in row_headers:
        row_lows.append(min(min(cell[0][0], cell[1][0]) for cell in intervals_by_row[row_header]))
    row_highs = []
    for row_header in row_headers:
        row_highs.append(max(max(cell[0][1], cell[1][1]) for cell in intervals_by_row[row_header]))
    return (min(row_lows), max(row_highs))
95
96
def pretty_print_percentage_difference(baseline_value, current_value):
    """Describe the change from baseline_value to current_value as a signed percentage.

    Both arguments are (low, high) intervals. The reported range is the widest
    possible change: current low vs baseline high, up to current high vs baseline low.
    Returns a single percentage (e.g. '+3.2%') when both ends format identically,
    otherwise 'LOW - HIGH'.
    """
    lowest_change, highest_change = (
        "%+.1f%%" % ((current_value[0] / baseline_value[1] - 1) * 100),
        "%+.1f%%" % ((current_value[1] / baseline_value[0] - 1) * 100),
    )
    if lowest_change != highest_change:
        return "%s - %s" % (lowest_change, highest_change)
    return lowest_change
110
111
def print_confidence_intervals_table(table_name,
                                     table_data,
                                     baseline_table_data,
                                     column_header_pretty_printer=identity,
                                     row_header_pretty_printer=identity,
                                     value_pretty_printer=identity):
    """Print a table of confidence intervals as a markdown table, optionally compared with a baseline.

    Args:
        table_name: printed in the top-left header cell (or in the "(no data)" line).
        table_data: dict of dicts; table_data[row_key][column_key] is a
            (raw_confidence_interval, rounded_confidence_interval) pair.
        baseline_table_data: optional table with the same structure holding the "before" state.
            When present (and non-empty), matching cells are printed as "before -> after (change %)".
            Baseline rows/columns without a match in table_data are reported and ignored.
        column_header_pretty_printer, row_header_pretty_printer: functions taking a single header
            value and returning its printed form.
        value_pretty_printer: function taking (rounded_confidence_interval, min_in_table, max_in_table).
            The default `identity` only accepts one argument, so in practice a value printer must be passed.
    """
    if not table_data:
        print('%s: (no data)' % table_name)
        return

    row_headers = sorted(table_data.keys())
    # We need to compute the union of the headers of all rows; some rows might be missing values for certain columns.
    column_headers = sorted(set().union(*[row_values.keys() for row_values in table_data.values()]))

    min_in_table, max_in_table = compute_min_max(table_data, row_headers, column_headers)

    if baseline_table_data:
        baseline_row_headers = sorted(baseline_table_data.keys())
        baseline_column_headers = sorted(
            set().union(*[row_values.keys() for row_values in baseline_table_data.values()]))
        unmatched_baseline_column_headers = set(baseline_column_headers) - set(column_headers)
        if unmatched_baseline_column_headers:
            print('Found baseline column headers with no match in new results (they will be ignored): ', unmatched_baseline_column_headers)
        unmatched_baseline_row_headers = set(baseline_row_headers) - set(row_headers)
        if unmatched_baseline_row_headers:
            print('Found baseline row headers with no match in new results (they will be ignored): ', unmatched_baseline_row_headers)

        # Baseline values are printed with the same unit as the new values, so they must be taken into
        # account when choosing it.
        min_in_baseline_table, max_in_baseline_table = compute_min_max(
            baseline_table_data, baseline_row_headers, baseline_column_headers)
        min_in_table = min(min_in_table, min_in_baseline_table)
        max_in_table = max(max_in_table, max_in_baseline_table)

    table_content = [[table_name] + [column_header_pretty_printer(column_header) for column_header in column_headers]]
    for row_header in row_headers:
        row_values = table_data[row_header]
        # .get() (rather than []) so that a defaultdict baseline doesn't grow empty rows.
        baseline_row_values = baseline_table_data.get(row_header, {}) if baseline_table_data else {}
        row_content = [row_header_pretty_printer(row_header)]
        for column_header in column_headers:
            if column_header not in row_values:
                row_content.append("N/A")
                continue
            raw_confidence_interval, rounded_confidence_interval = row_values[column_header]
            pretty_printed_value = value_pretty_printer(rounded_confidence_interval, min_in_table, max_in_table)
            if column_header in baseline_row_values:
                raw_baseline_confidence_interval, rounded_baseline_confidence_interval = baseline_row_values[column_header]
                pretty_printed_baseline_value = value_pretty_printer(rounded_baseline_confidence_interval, min_in_table, max_in_table)
                # The percentage is computed on the raw (unrounded) intervals.
                pretty_printed_percentage_difference = pretty_print_percentage_difference(raw_baseline_confidence_interval, raw_confidence_interval)
                row_content.append("%s -> %s (%s)" % (pretty_printed_baseline_value, pretty_printed_value, pretty_printed_percentage_difference))
            else:
                row_content.append(pretty_printed_value)
        table_content.append(row_content)
    print_markdown_table(table_content)
167
168
def format_string_pretty_printer(format_string):
    """Return a printer that applies the %-style format_string to its argument.

    Note that `format_string % value` treats a tuple value as multiple format arguments.
    """
    return lambda value: format_string % value
174
175
def interval_pretty_printer(interval, unit, multiplier):
    """Format a (low, high) interval scaled by `multiplier`, followed by `unit`.

    Returns 'X unit' when both formatted endpoints are identical, otherwise 'X-Y unit'.
    `interval` may be any indexable pair (list or tuple); it is not modified.
    """
    def format_endpoint(value):
        # Whole numbers >= 10 are printed as plain integers: '%.3g' would round them to 3 significant
        # digits and switch to exponent notation (e.g. 1234 -> '1.23e+03').
        # Everything else uses 3 significant digits; note that '%.3g' also drops trailing zeros,
        # so e.g. 2.0 is printed as '2'.
        if int(value) == value and value >= 10:
            return int(value)
        return '%.3g' % value

    low = format_endpoint(interval[0] * multiplier)
    high = format_endpoint(interval[1] * multiplier)

    if low == high:
        return '%s %s' % (low, unit)
    else:
        return '%s-%s %s' % (low, high, unit)
197
198
199# Finds the best unit to represent values in the range [min_value, max_value].
200# The units must be specified as an ordered list [multiplier1, ..., multiplierN]
201def find_best_unit(units, min_value, max_value):
202    assert min_value <= max_value
203    if max_value <= units[0]:
204        return units[0]
205    for i in range(len(units) - 1):
206        if min_value > units[i] and max_value < units[i + 1]:
207            return units[i]
208    if min_value > units[-1]:
209        return units[-1]
210    # There is no unit that works very well for all values, first let's try relaxing the min constraint
211    for i in range(len(units) - 1):
212        if min_value > units[i] * 0.2 and max_value < units[i + 1]:
213            return units[i]
214    if min_value > units[-1] * 0.2:
215        return units[-1]
216    # That didn't work either, just use a unit that works well for the min values then
217    for i in reversed(range(len(units))):
218        if min_value > units[i]:
219            return units[i]
220    assert min_value <= min(units)
221    # Pick the smallest unit
222    return units[0]
223
224
def time_interval_pretty_printer(time_interval, min_in_table, max_in_table):
    """Format a time interval given in seconds, choosing μs, ms or s for the whole table.

    The unit is chosen from [min_in_table, max_in_table] (via find_best_unit) so that
    all values of a table share it.
    """
    sec = 1
    milli = 0.001
    micro = milli * milli
    named_units = [(micro, 'μs'), (milli, 'ms'), (sec, 's')]

    chosen = find_best_unit([multiplier for multiplier, _ in named_units], min_in_table, max_in_table)
    return interval_pretty_printer(time_interval,
                                   unit=dict(named_units)[chosen],
                                   multiplier=1 / chosen)
236
237
def file_size_interval_pretty_printer(file_size_interval, min_in_table, max_in_table):
    """Format a size interval given in bytes, choosing bytes, KB or MB (powers of 1024) for the whole table.

    The unit is chosen from [min_in_table, max_in_table] (via find_best_unit) so that
    all values of a table share it.
    """
    byte = 1
    kb = 1024
    mb = kb * kb
    named_units = [(byte, 'bytes'), (kb, 'KB'), (mb, 'MB')]

    chosen = find_best_unit([multiplier for multiplier, _ in named_units], min_in_table, max_in_table)
    return interval_pretty_printer(file_size_interval,
                                   unit=dict(named_units)[chosen],
                                   multiplier=1 / chosen)
249
250
def make_immutable(x):
    """Recursively convert lists into tuples so the value can be hashed and used as a dict key.

    Values that are not lists (including tuples) are returned as-is.
    """
    if not isinstance(x, list):
        return x
    return tuple(map(make_immutable, x))
255
256
def dict_pretty_printer(dict_data):
    """Return a printer that maps each value through a fixed lookup table.

    dict_data is either a dict, or a list of {'from': ..., 'to': ...} mappings (the
    'from' values are passed through make_immutable so list values can be keys).
    The printer raises an Exception for values that are not in the table.
    """
    lookup = dict_data
    if isinstance(lookup, list):
        lookup = {make_immutable(entry['from']): entry['to'] for entry in lookup}

    def pretty_print(s):
        if s not in lookup:
            raise Exception('dict_pretty_printer(%s) can\'t handle the value %s' % (lookup, s))
        return lookup[s]

    return pretty_print
267
268
def determine_column_pretty_printer(pretty_printer_definition):
    """Build a header pretty printer from its definition.

    A definition containing 'format_string' yields a format-string printer (checked first);
    one containing 'fixed_map' yields a lookup-table printer. Anything else raises an Exception.
    """
    definition = pretty_printer_definition
    if 'format_string' in definition:
        return format_string_pretty_printer(definition['format_string'])
    elif 'fixed_map' in definition:
        return dict_pretty_printer(definition['fixed_map'])
    else:
        raise Exception("Unrecognized pretty printer description: %s" % definition)
277
278
def determine_row_pretty_printer(pretty_printer_definition):
    """Build a row-header pretty printer; rows accept exactly the same definitions as columns."""
    row_printer = determine_column_pretty_printer(pretty_printer_definition)
    return row_printer
281
282
def determine_value_pretty_printer(unit):
    """Return the interval pretty printer for a results unit ("seconds" or "bytes").

    Raises an Exception for any other unit.
    """
    if unit == "seconds":
        printer = time_interval_pretty_printer
    elif unit == "bytes":
        printer = file_size_interval_pretty_printer
    else:
        raise Exception("Unrecognized unit: %s" % unit)
    return printer
289
290
def main():
    """Command-line entry point: print the benchmark tables defined in a YAML file as markdown.

    Reads --benchmark-results (and optionally --baseline-benchmark-results), each containing one
    JSON benchmark result per line. For every entry of the "tables" list in
    --benchmark-tables-definition, prints a markdown table to stdout, comparing with the
    baseline results when they are given.
    """
    parser = argparse.ArgumentParser(description='Prints benchmark results as markdown tables, as defined in a YAML benchmark tables definition file.')
    parser.add_argument('--benchmark-results',
                        help='The input file where benchmark results will be read from (1 per line, with each line in JSON format). You can use the run_benchmarks.py to run a benchmark and generate results in this format.')
    parser.add_argument('--baseline-benchmark-results',
                        help='Optional. If specified, compares this file (considered the "before" state) with the one specified in --benchmark-results.')
    parser.add_argument('--benchmark-tables-definition', help='The YAML file that defines the benchmark tables (e.g. fruit_wiki_bench_tables.yaml).')
    args = parser.parse_args()

    if args.benchmark_results is None:
        raise Exception("You must specify a benchmark results file using --benchmark-results.")

    if args.benchmark_tables_definition is None:
        raise Exception("You must specify a benchmark tables definition file using --benchmark-tables-definition.")

    def load_bench_results(path):
        # One JSON object per line; blank lines (e.g. a trailing empty line) are ignored
        # instead of making json.loads fail.
        with open(path, 'r') as f:
            return [json.loads(line) for line in f if line.strip()]

    bench_results = load_bench_results(args.benchmark_results)
    if args.baseline_benchmark_results:
        baseline_bench_results = load_bench_results(args.baseline_benchmark_results)
    else:
        baseline_bench_results = None

    with open(args.benchmark_tables_definition, 'r') as f:
        # safe_load: the definition is plain data. yaml.load() without an explicit Loader is unsafe
        # in older PyYAML releases and is rejected outright by PyYAML >= 6.
        table_definitions = yaml.safe_load(f)["tables"]

    for table_definition in table_definitions:
        try:
            fixed_benchmark_params = {dimension_name: make_immutable(dimension_value)
                                      for dimension_name, dimension_value in table_definition['benchmark_filter'].items()}
            # The same table layout is used for the new and the baseline results.
            dimensions = dict(
                column_dimension=table_definition['columns']['dimension'],
                row_dimension=table_definition['rows']['dimension'],
                result_dimension=table_definition['results']['dimension'])
            table_data = extract_results(bench_results, fixed_benchmark_params=fixed_benchmark_params, **dimensions)
            if baseline_bench_results:
                baseline_table_data = extract_results(
                    baseline_bench_results, fixed_benchmark_params=fixed_benchmark_params, **dimensions)
            else:
                baseline_table_data = None
            rows_pretty_printer_definition = table_definition['rows']['pretty_printer']
            columns_pretty_printer_definition = table_definition['columns']['pretty_printer']
            results_unit = table_definition['results']['unit']
            print_confidence_intervals_table(table_definition['name'],
                                             table_data,
                                             baseline_table_data,
                                             column_header_pretty_printer=determine_column_pretty_printer(columns_pretty_printer_definition),
                                             row_header_pretty_printer=determine_row_pretty_printer(rows_pretty_printer_definition),
                                             value_pretty_printer=determine_value_pretty_printer(results_unit))
            print()
            print()
        except Exception:
            # table_definition is a dict: it must be formatted, not concatenated to a str (that would
            # raise a TypeError that hides the original error).
            print('While processing table:\n%s' % (table_definition,))
            print()
            raise


if __name__ == "__main__":
    main()
354