1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# Copyright 2020 The Chromium OS Authors. All rights reserved.
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Checks for reverts of commits across a given git commit.
8
9To clarify the meaning of 'across' with an example, if we had the following
10commit history (where `a -> b` notes that `b` is a direct child of `a`):
11
12123abc -> 223abc -> 323abc -> 423abc -> 523abc
13
14And where 423abc is a revert of 223abc, this revert is considered to be 'across'
15323abc. More generally, a revert A of a parent commit B is considered to be
16'across' a commit C if C is a parent of A and B is a parent of C.
17
18Please note that revert detection in general is really difficult, since merge
19conflicts/etc always introduce _some_ amount of fuzziness. This script just
20uses a bundle of heuristics, and is bound to ignore / incorrectly flag some
21reverts. The hope is that it'll easily catch the vast majority (>90%) of them,
22though.
23"""
24
25# pylint: disable=cros-logging-import
26
27from __future__ import print_function
28
29import argparse
30import collections
31import logging
32import re
33import subprocess
34import sys
35import typing as t
36
37# People are creative with their reverts, and heuristics are a bit difficult.
# Like 90% of reverts have "This reverts commit ${full_sha}".
39# Some lack that entirely, while others have many of them specified in ad-hoc
40# ways, while others use short SHAs and whatever.
41#
42# The 90% case is trivial to handle (and 100% free + automatic). The extra 10%
43# starts involving human intervention, which is probably not worth it for now.
44
45
def _try_parse_reverts_from_commit_message(commit_message: str) -> t.List[str]:
  """Extracts the SHAs that `commit_message` claims to revert.

  Two spellings are recognized:
    - 'This reverts commit <40 hex chars>' anywhere in the message, and
    - a first line that starts with 'Revert <6+ hex chars> "'.

  Args:
    commit_message: Full text of a commit message; may be empty.

  Returns:
    Every SHA found, body matches first and the subject-line match (if any)
    last. SHAs from the subject line may be abbreviated, and the same commit
    may show up more than once.
  """
  if not commit_message:
    return []

  body_pattern = re.compile(r'This reverts commit ([a-f0-9]{40})\b')
  reverted = [found.group(1) for found in body_pattern.finditer(commit_message)]

  # Only the subject line is checked for the `Revert <sha> "..."` spelling.
  subject = commit_message.splitlines()[0]
  subject_match = re.match(r'Revert ([a-f0-9]{6,}) "', subject)
  if subject_match is not None:
    reverted.append(subject_match.group(1))
  return reverted
57
58
def _stream_stdout(command: t.List[str]) -> t.Generator[str, None, None]:
  """Runs `command` and lazily yields its stdout, one line at a time.

  Output is decoded as UTF-8, with undecodable bytes replaced rather than
  raising. Lines keep their trailing newline. The exit status of the command is
  not checked.

  Args:
    command: argv of the process to launch.

  Yields:
    Each line the process writes to stdout.
  """
  popen_kwargs = dict(stdout=subprocess.PIPE, encoding='utf-8', errors='replace')
  with subprocess.Popen(command, **popen_kwargs) as child:
    for output_line in child.stdout:
      yield output_line
63
64
def _resolve_sha(git_dir: str, sha: str) -> str:
  """Expands a possibly-abbreviated SHA into whatever `git rev-parse` prints.

  Args:
    git_dir: Directory passed to `git -C`.
    sha: A SHA; anything that is exactly 40 characters long is assumed to
      already be complete and is returned without consulting git.

  Returns:
    `sha` itself if it is 40 characters long, otherwise the stripped output of
    `git rev-parse`.

  Raises:
    subprocess.CalledProcessError: if `git rev-parse` exits non-zero.
  """
  if len(sha) != 40:
    # git's own complaints are discarded; failure surfaces as an exception.
    rev_parse_output = subprocess.check_output(
        ['git', '-C', git_dir, 'rev-parse', sha],
        encoding='utf-8',
        stderr=subprocess.DEVNULL,
    )
    return rev_parse_output.strip()
  return sha
74
75
class _LogEntry(t.NamedTuple):
  """A single commit parsed out of `git log` output."""

  # Commit hash, as printed by git's `%H` format placeholder.
  sha: str
  # The whole commit message as one newline-joined string (not a list of
  # lines), with trailing whitespace removed.
  commit_message: str
80
81
def _log_stream(git_dir: str, root_sha: str,
                end_at_sha: str) -> t.Iterable[_LogEntry]:
  """Lazily yields a _LogEntry for each commit `git log ^end_at_sha root_sha` prints.

  Each commit is printed as a separator line, then the commit hash on its own
  line, then the raw message body. The separator is how commit boundaries are
  recognized while streaming.

  Args:
    git_dir: Directory passed to `git -C`.
    root_sha: Commit to start logging from.
    end_at_sha: Commit whose history is excluded from the log.

  Yields:
    One _LogEntry per commit, in the order git prints them.
  """
  separator = '<>' * 50
  lines = _stream_stdout([
      'git',
      '-C',
      git_dir,
      'log',
      '^' + end_at_sha,
      root_sha,
      '--format=' + separator + '%n%H%n%B%n',
  ])

  # Skip ahead to the first separator. With nothing to log there is none, and
  # git may print other lines ahead of it. `any` stops consuming right after
  # the first matching line.
  at_commit_header = any(line.rstrip() == separator for line in lines)

  while at_commit_header:
    # crbug.com/1041148
    # pylint: disable=stop-iteration-return
    sha_line = next(lines, None)
    assert sha_line is not None, 'git died?'

    # Everything up to the next separator (or end of output) is the message.
    message_lines = []
    at_commit_header = False
    for raw_line in lines:
      stripped = raw_line.rstrip()
      if stripped == separator:
        at_commit_header = True
        break
      message_lines.append(stripped)

    yield _LogEntry(sha_line.rstrip(), '\n'.join(message_lines).rstrip())
123
124
def _shas_between(git_dir: str, base_ref: str,
                  head_ref: str) -> t.Iterable[str]:
  """Lazily lists the commits `git rev-list --first-parent base_ref..head_ref` prints.

  Args:
    git_dir: Directory passed to `git -C`.
    base_ref: Left-hand side of the `..` range.
    head_ref: Right-hand side of the `..` range.

  Returns:
    An iterable of whitespace-stripped output lines, one per commit. git is not
    read from until the iterable is consumed.
  """
  commit_range = '{}..{}'.format(base_ref, head_ref)
  rev_list_lines = _stream_stdout(
      ['git', '-C', git_dir, 'rev-list', '--first-parent', commit_range])
  return map(str.strip, rev_list_lines)
136
137
def _rev_parse(git_dir: str, ref: str) -> str:
  """Returns the stripped output of `git rev-parse ref`, run inside `git_dir`.

  Args:
    git_dir: Directory passed to `git -C`.
    ref: Ref or SHA to resolve.

  Raises:
    subprocess.CalledProcessError: if git exits non-zero.
  """
  command = ['git', '-C', git_dir, 'rev-parse', ref]
  raw_output = subprocess.check_output(command, encoding='utf-8')
  return t.cast(str, raw_output.strip())
144
145
class Revert(t.NamedTuple):
  """A claim, found in a commit message, that one commit reverts another."""

  # Commit whose message contains the revert claim.
  sha: str
  # Commit that `sha` says it reverts.
  reverted_sha: str
150
151
def find_common_parent_commit(git_dir: str, ref_a: str, ref_b: str) -> str:
  """Returns the stripped output of `git merge-base ref_a ref_b` in `git_dir`.

  Args:
    git_dir: Directory passed to `git -C`.
    ref_a: First ref or SHA.
    ref_b: Second ref or SHA.

  Raises:
    subprocess.CalledProcessError: if git exits non-zero.
  """
  merge_base_cmd = ['git', '-C', git_dir, 'merge-base', ref_a, ref_b]
  merge_base_out = subprocess.check_output(merge_base_cmd, encoding='utf-8')
  return merge_base_out.strip()
157
158
def find_reverts(git_dir: str, across_ref: str, root: str) -> t.List[Revert]:
  """Finds reverts across `across_ref` in `git_dir`, starting from `root`.

  Every commit logged for `^across_ref root` is scanned for revert claims in
  its message. A claim is reported unless the reverted commit is itself one of
  the first-parent commits in `across_ref..root`, i.e. the revert and the
  reverted commit both landed after `across_ref`.

  Args:
    git_dir: Directory passed to `git -C`.
    across_ref: Ref or SHA that reverts must reach across. Its merge base with
      `root` has to be `across_ref` itself.
    root: Ref or SHA to start walking history from.

  Returns:
    Reverts in the order their reverting commits are logged. When a single
    commit reverts several others, those are ordered by reverted SHA.

  Raises:
    ValueError: if the merge base of `across_ref` and `root` isn't
      `across_ref`.
    subprocess.CalledProcessError: if git fails to resolve `across_ref` or
      `root`, or to compute their merge base.
  """
  across_sha = _rev_parse(git_dir, across_ref)
  root_sha = _rev_parse(git_dir, root)

  common_ancestor = find_common_parent_commit(git_dir, across_sha, root_sha)
  if common_ancestor != across_sha:
    raise ValueError("%s isn't an ancestor of %s (common ancestor: %s)" %
                     (across_sha, root_sha, common_ancestor))

  intermediate_commits = set(_shas_between(git_dir, across_sha, root_sha))
  # Compare the resolved SHA, not the caller-supplied ref: the set only holds
  # SHAs, so a symbolic ref (e.g. a branch name) could never be found in it.
  assert across_sha not in intermediate_commits

  logging.debug('%d commits appear between %s and %s',
                len(intermediate_commits), across_sha, root_sha)

  all_reverts = []
  for sha, commit_message in _log_stream(git_dir, root_sha, across_sha):
    reverts = _try_parse_reverts_from_commit_message(commit_message)
    if not reverts:
      continue

    # Commit messages are free-form text, so an abbreviated SHA in one may be
    # unknown or ambiguous. Treat that like any other unresolvable claim: warn
    # and keep scanning rather than aborting the whole run.
    resolved_reverts = set()
    for claimed_sha in reverts:
      try:
        resolved_reverts.add(_resolve_sha(git_dir, claimed_sha))
      except subprocess.CalledProcessError:
        logging.warning(
            'Failed to resolve reverted object %s (claimed to be reverted '
            'by sha %s)', claimed_sha, sha)

    for reverted_sha in sorted(resolved_reverts):
      if reverted_sha in intermediate_commits:
        logging.debug('Commit %s reverts %s, which happened after %s', sha,
                      reverted_sha, across_sha)
        continue

      # A full-length SHA is never checked by _resolve_sha, so confirm here
      # that the object exists and really is a commit.
      try:
        object_type = subprocess.check_output(
            ['git', '-C', git_dir, 'cat-file', '-t', reverted_sha],
            encoding='utf-8',
            stderr=subprocess.DEVNULL,
        ).strip()
      except subprocess.CalledProcessError:
        logging.warning(
            'Failed to resolve reverted object %s (claimed to be reverted '
            'by sha %s)', reverted_sha, sha)
        continue

      if object_type == 'commit':
        all_reverts.append(Revert(sha, reverted_sha))
        continue

      logging.error("%s claims to revert %s -- which isn't a commit -- %s", sha,
                    object_type, reverted_sha)

  return all_reverts
208
209
def main(args: t.List[str]) -> int:
  """Command-line entry point: prints every revert found for the given roots.

  Args:
    args: Command-line arguments, without the program name.

  Returns:
    0 once all reverts have been printed. Errors surface as exceptions (or as
    argparse's own SystemExit for bad arguments).
  """
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
  parser.add_argument(
      'base_ref', help='Git ref or sha to check for reverts around.')
  parser.add_argument(
      '-C', '--git_dir', default='.', help='Git directory to use.')
  parser.add_argument(
      'root', nargs='+', help='Root(s) to search for commits from.')
  parser.add_argument('--debug', action='store_true')
  opts = parser.parse_args(args)

  logging.basicConfig(
      format='%(asctime)s: %(levelname)s: %(filename)s:%(lineno)d: %(message)s',
      level=logging.DEBUG if opts.debug else logging.INFO,
  )

  # `root`s can have related history, so we want to filter duplicate commits
  # out. The overwhelmingly common case is also to have one root, and it's way
  # easier to reason about output that comes in an order that's meaningful to
  # git.
  all_reverts = collections.OrderedDict()
  for root in opts.root:
    for revert in find_reverts(opts.git_dir, opts.base_ref, root):
      all_reverts[revert] = None

  for revert in all_reverts:
    print('%s claims to revert %s' % (revert.sha, revert.reverted_sha))

  # The signature promises an int; make the success exit status explicit
  # instead of falling off the end and returning None.
  return 0
238
239
if __name__ == '__main__':
  # Run as a script: hand main() the arguments after the program name and use
  # its return value as the process exit status.
  exit_status = main(sys.argv[1:])
  sys.exit(exit_status)
242