1# Copyright 2019 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Module containing different output formatters for the bloat script."""
15
16import abc
17import enum
18from typing import (Callable, Collection, Dict, List, Optional, Tuple, Type,
19                    Union)
20
21from pw_bloat.binary_diff import BinaryDiff, FormattedDiff
22
23
class Output(abc.ABC):
    """Abstract base for size report card renderers.

    A subclass renders the stored title and binary diffs in one particular
    output format by implementing ``diff`` and ``absolute``.
    """
    def __init__(self,
                 title: Optional[str],
                 diffs: Collection[BinaryDiff] = ()):
        # Keep both inputs around for the subclass's rendering methods.
        self._title, self._diffs = title, diffs

    @abc.abstractmethod
    def diff(self) -> str:
        """Renders a report card of binary size deltas against a base."""

    @abc.abstractmethod
    def absolute(self) -> str:
        """Renders a report card of each binary's absolute size breakdown."""
39
40
class AsciiCharset(enum.Enum):
    """Set of ASCII characters for drawing tables.

    Junction names are a row letter followed by a column letter: T/M/B for
    the top border, an inner separator row, or the bottom border, and L/M/R
    for the left edge, a point between two columns, or the right edge. V is
    the vertical cell border, H the horizontal rule, and HH the rule drawn
    under the column headings.

    Note: because every junction member shares the value '+', Python's Enum
    makes TM through BR aliases of TL (iteration yields only TL, V, H, HH).
    Table-drawing code reads only ``.value``, so this is harmless.
    """
    TL = '+'  # Top border, left edge.
    TM = '+'  # Top border, between columns.
    TR = '+'  # Top border, right edge.
    ML = '+'  # Inner separator row, left edge.
    MM = '+'  # Inner separator row, between columns.
    MR = '+'  # Inner separator row, right edge.
    BL = '+'  # Bottom border, left edge.
    BM = '+'  # Bottom border, between columns.
    BR = '+'  # Bottom border, right edge.
    V = '|'  # Vertical cell border.
    H = '-'  # Horizontal rule.
    HH = '='  # Horizontal rule below the column headings.
55
56
class LineCharset(enum.Enum):
    """Set of line-drawing characters for tables.

    Uses the same member names as AsciiCharset, so either class can be passed
    as a table's charset. Junction names are a row letter (T/M/B: top border,
    inner separator row, bottom border) followed by a column letter (L/M/R:
    left edge, between columns, right edge). V is the vertical cell border, H
    the horizontal rule, and HH the double rule under the column headings.
    """
    TL = '┌'  # Top border, left edge.
    TM = '┬'  # Top border, between columns.
    TR = '┐'  # Top border, right edge.
    ML = '├'  # Inner separator row, left edge.
    MM = '┼'  # Inner separator row, between columns.
    MR = '┤'  # Inner separator row, right edge.
    BL = '└'  # Bottom border, left edge.
    BM = '┴'  # Bottom border, between columns.
    BR = '┘'  # Bottom border, right edge.
    V = '│'  # Vertical cell border.
    H = '─'  # Horizontal rule.
    HH = '═'  # Double horizontal rule below the column headings.
71
72
def identity(val: str) -> str:
    """Pass-through hook: hands back ``val`` exactly as it was given.

    Serves as the default ``preprocess`` callable for table cells.
    """
    unchanged = val
    return unchanged
76
77
class TableOutput(Output):
    """Tabular output.

    Renders each BinaryDiff as one table section containing one row per
    formatted segment. Sections are separated by horizontal rules drawn with
    the characters of the selected charset.
    """

    # Heading of the first column, which holds each diff's label.
    LABEL_COLUMN = 'Label'

    def __init__(
            self,
            title: Optional[str],
            diffs: Collection[BinaryDiff] = (),
            charset: Union[Type[AsciiCharset],
                           Type[LineCharset]] = AsciiCharset,
            preprocess: Callable[[str], str] = identity,
            # TODO(frolv): Make this a Literal type.
            justify: str = 'rjust'):
        """Configures a table report.

        Args:
          title: Optional heading printed centered above the table, or None
              for no heading.
          diffs: The binary diffs to render, one table section per diff.
          charset: Enum class supplying the table-drawing characters.
          preprocess: Applied to every value cell before it is measured and
              rendered (not applied to labels or column headings).
          justify: Name of the ``str`` method used to pad labels and value
              cells to their column width, e.g. 'rjust' or 'ljust'.
        """
        self._cs = charset
        self._preprocess = preprocess
        self._justify = justify

        super().__init__(title, diffs)

    def diff(self) -> str:
        """Build a tabular diff output showing binary size deltas.

        Returns:
          The rendered table as a single newline-joined string. When there
          are no diffs, the table holds only the heading row and is closed
          with the bottom border.
        """

        # Calculate the width of each column in the table. Column headings
        # set the minimum width; values are measured after preprocessing
        # because the preprocessed text is what gets rendered.
        max_label = len(self.LABEL_COLUMN)
        column_widths = [len(field) for field in FormattedDiff._fields]

        for diff in self._diffs:
            max_label = max(max_label, len(diff.label))
            for segment in diff.formatted_segments():
                for i, val in enumerate(segment):
                    val = self._preprocess(val)
                    column_widths[i] = max(column_widths[i], len(val))

        separators = self._row_separators([max_label] + column_widths)

        def title_pad(string: str) -> str:
            # Left-pad so the string sits centered over the table's width.
            padding = (len(separators['top']) - len(string)) // 2
            return ' ' * padding + string

        titles = [
            self._center_align(val.capitalize(), column_widths[i])
            for i, val in enumerate(FormattedDiff._fields)
        ]
        column_names = [self._center_align(self.LABEL_COLUMN, max_label)
                        ] + titles

        rows: List[str] = []

        if self._title is not None:
            rows.extend([
                title_pad(self._title),
                title_pad(self._cs.H.value * len(self._title)),
            ])

        rows.extend([
            separators['top'],
            self._table_row(column_names),
        ])

        if not self._diffs:
            # With no body rows, ending on the heading separator would leave
            # the table unclosed (and an RST grid table may not end on its
            # head/body separator), so close it below the headings instead.
            rows.append(separators['bot'])
            return '\n'.join(rows)

        rows.append(separators['hdg'])

        last_row = len(self._diffs) - 1
        for row, diff in enumerate(self._diffs):
            subrows: List[str] = []

            for segment in diff.formatted_segments():
                subrow: List[str] = []
                # Only a diff's first segment shows its label; the rest leave
                # the label column blank.
                label = diff.label if not subrows else ''
                subrow.append(getattr(label, self._justify)(max_label, ' '))
                subrow.extend([
                    getattr(self._preprocess(val),
                            self._justify)(column_widths[i], ' ')
                    for i, val in enumerate(segment)
                ])
                subrows.append(self._table_row(subrow))

            rows.append('\n'.join(subrows))
            # The last section is closed by the bottom border; the others
            # are followed by an inner separator.
            rows.append(separators['bot' if row == last_row else 'mid'])

        return '\n'.join(rows)

    def absolute(self) -> str:
        """Absolute size reports are not implemented for tables.

        Returns:
          Always an empty string.
        """
        return ''

    def _row_separators(self, column_widths: List[int]) -> Dict[str, str]:
        """Returns row separators for a table based on the character set.

        Args:
          column_widths: Content width of each column, label column first.

        Returns:
          A dict with keys 'top' (top border), 'hdg' (rule below the column
          headings, drawn with the HH character), 'mid' (rule between diff
          sections), and 'bot' (bottom border).
        """

        # Left, middle, and right characters for each of the separator rows.
        top = (self._cs.TL.value, self._cs.TM.value, self._cs.TR.value)
        mid = (self._cs.ML.value, self._cs.MM.value, self._cs.MR.value)
        bot = (self._cs.BL.value, self._cs.BM.value, self._cs.BR.value)

        def sep(chars: Tuple[str, str, str], heading: bool = False) -> str:
            line = self._cs.HH.value if heading else self._cs.H.value
            # One extra line character on each side of a column's rule lines
            # up with the single space _table_row pads around each cell.
            lines = [line * width for width in column_widths]
            left = f'{chars[0]}{line}'
            middle = f'{line}{chars[1]}{line}'.join(lines)
            right = f'{line}{chars[2]}'
            return f'{left}{middle}{right}'

        return {
            'top': sep(top),
            'hdg': sep(mid, True),
            'mid': sep(mid),
            'bot': sep(bot),
        }

    def _table_row(self, vals: Collection[str]) -> str:
        """Formats a row of the table with the selected character set.

        Cells are joined by the vertical border character, with one space of
        padding on each side of every cell and borders on both ends.
        """
        vert = self._cs.V.value
        main = f' {vert} '.join(vals)
        return f'{vert} {main} {vert}'

    @staticmethod
    def _center_align(val: str, width: int) -> str:
        """Left and right pads a value with spaces to center within a width.

        When the total padding is odd, the extra space goes on the left. A
        value already at least ``width`` long is returned unchanged.
        """
        # Clamp at 0: a negative difference would otherwise make
        # `space % 2 == 1` true and prepend a stray space.
        space = max(width - len(val), 0)
        padding = ' ' * (space // 2)
        extra = ' ' if space % 2 == 1 else ''
        return f'{extra}{padding}{val}{padding}'
198
199
class RstOutput(TableOutput):
    """ASCII table renderer whose output is also a valid RST table.

    No title is drawn; values are left-justified.
    """
    def __init__(self, diffs: Collection[BinaryDiff] = ()):
        # Prefix each value cell with an RST line-block marker so that, in
        # the rendered HTML, every value in a cell starts on a new line.
        def as_line_block(cell: str) -> str:
            return f'| {cell}'

        super().__init__(title=None,
                         diffs=diffs,
                         charset=AsciiCharset,
                         preprocess=as_line_block,
                         justify='ljust')
213