1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# Copyright 2019 The Chromium OS Authors. All rights reserved.
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Maps LLVM git SHAs to synthetic revision numbers and back.
8
Revision numbers are all of the form '(branch_name, r1234)'. As a shorthand,
r1234 is parsed as '(main, r1234)'.
11"""
12
13from __future__ import print_function
14
15import argparse
16import re
17import subprocess
18import sys
19import typing as t
20
# Branch that the bare 'r1234' shorthand refers to.
MAIN_BRANCH = 'main'

# |base_llvm_sha| is the commit that carries revision number
# |base_llvm_revision|. Revisions at or before it are looked up through the
# 'llvm-svn: N' line that ends their commit message; revisions after it are
# computed by counting commits starting from |base_llvm_sha|.
#
# Past |base_llvm_sha| commit messages are The Wild West(TM), so perfectly
# reasonable input that could break us includes:
#
#   Revert foo
#
#   This reverts foo, which had the commit message:
#
#   bar
#   llvm-svn: 375505
#
# While saddening, this is something we should probably try to handle
# reasonably.
base_llvm_sha = '186155b89c2d2a2f62337081e3ca15f676c9434b'
base_llvm_revision = 375505


class LLVMConfig(t.NamedTuple):
  # Describes an LLVM git checkout.

  # Name of the LLVM remote. Generally it's "origin".
  remote: str
  # Directory of the LLVM checkout.
  dir: str
42
43
class Rev(t.NamedTuple('Rev', (('branch', str), ('number', int)))):
  """Represents an LLVM 'revision': a shorthand that identifies an LLVM commit.

  A Rev is a (branch, number) pair. Its string form is '(branch, r1234)', or
  just 'r1234' when the branch is MAIN_BRANCH.
  """

  @staticmethod
  def parse(rev: str) -> 'Rev':
    """Parses a Rev from the given string.

    Accepted spellings (surrounding whitespace is ignored):
      - 'r1234', shorthand for revision 1234 on MAIN_BRANCH
      - '(branch_name, r1234)'

    Raises a ValueError on a failed parse.
    """
    text = rev.strip()

    # Match the entire string instead of handing the tail to int(): int()
    # tolerates signs, underscores and padding, so 'r-5' or 'r1_0' would
    # otherwise be accepted as revisions.
    shorthand = re.fullmatch(r'r(\d+)', text)
    if shorthand:
      return Rev(branch=MAIN_BRANCH, number=int(shorthand.group(1)))

    # fullmatch (rather than match) so that trailing garbage after the closing
    # parenthesis is rejected.
    full = re.fullmatch(r'\((.+), r(\d+)\)', text)
    if not full:
      raise ValueError("%r isn't a valid revision" % rev)

    branch_name, rev_string = full.groups()
    return Rev(branch=branch_name, number=int(rev_string))

  def __str__(self) -> str:
    """Formats this Rev in the same spelling that parse() accepts."""
    branch_name, number = self
    if branch_name == MAIN_BRANCH:
      return 'r%d' % number
    return '(%s, r%d)' % (branch_name, number)
75
76
def is_git_sha(xs: str) -> bool:
  """Returns whether the given string looks like a valid git commit SHA.

  A string qualifies when it consists of 7 to 40 ASCII hexadecimal digits, so
  abbreviated SHAs are accepted as well as full ones.
  """
  # str.isdigit() is also true for non-ASCII digits (superscripts, other
  # scripts' numerals), so spell out the ASCII hex alphabet explicitly.
  return re.fullmatch(r'[0-9a-fA-F]{7,40}', xs) is not None
81
82
def check_output(command: t.List[str], cwd: str) -> str:
  """Runs |command| in directory |cwd| and returns its stdout as text.

  stdin is connected to /dev/null and stdout is decoded as UTF-8.

  Raises a subprocess.CalledProcessError if the command exits with a nonzero
  status.
  """
  return subprocess.check_output(
      command,
      cwd=cwd,
      stdin=subprocess.DEVNULL,
      encoding='utf-8',
  )
94
95
def translate_prebase_sha_to_rev_number(llvm_config: LLVMConfig,
                                        sha: str) -> int:
  """Translates a sha to a revision number (e.g., "llvm-svn: 1234").

  This function assumes that the given SHA is an ancestor of |base_llvm_sha|.

  The number is read from the last line of the commit message, which is
  expected to be of the form 'llvm-svn: N'.

  Raises a ValueError if the commit message doesn't end in such a line, and a
  subprocess.CalledProcessError if git fails.
  """
  commit_message = check_output(
      ['git', 'log', '-n1', '--format=%B', sha],
      cwd=llvm_config.dir,
  )
  # A commit with an empty message has no lines at all. Treat it like any
  # other commit lacking the llvm-svn line rather than dying on an IndexError.
  message_lines = commit_message.strip().splitlines()
  last_line = message_lines[-1] if message_lines else ''
  svn_match = re.match(r'^llvm-svn: (\d+)$', last_line)

  if not svn_match:
    raise ValueError(
        f"No llvm-svn line found for {sha}, which... shouldn't happen?")

  return int(svn_match.group(1))
114
115
def _count_first_parent_commits(llvm_config: LLVMConfig, start: str,
                                end: str) -> int:
  """Returns the number of first-parent commits in the range |start|..|end|."""
  count = check_output(
      [
          'git',
          'rev-list',
          '--count',
          '--first-parent',
          f'{start}..{end}',
      ],
      cwd=llvm_config.dir,
  )
  return int(count.strip())


def translate_sha_to_rev(llvm_config: LLVMConfig, sha_or_ref: str) -> Rev:
  """Translates a sha or git ref to a Rev.

  Commits that descend from |base_llvm_sha| are numbered by counting
  first-parent commits from it. Other commits are numbered relative to the
  llvm-svn number of their merge-base with |base_llvm_sha|, and are attributed
  to the single remote branch that contains them.

  Raises a ValueError if no branch, or more than one branch, of the configured
  remote contains the commit, and a subprocess.CalledProcessError if git fails.
  """
  # The comparisons below are against the full, lower-case SHAs that git
  # prints, so only a complete SHA may skip |git rev-parse|. Abbreviated SHAs
  # are expanded by git just like any other ref.
  if len(sha_or_ref) == 40 and is_git_sha(sha_or_ref):
    sha = sha_or_ref.lower()
  else:
    sha = check_output(
        ['git', 'rev-parse', sha_or_ref],
        cwd=llvm_config.dir,
    ).strip()

  merge_base = check_output(
      ['git', 'merge-base', base_llvm_sha, sha],
      cwd=llvm_config.dir,
  ).strip()

  if merge_base == base_llvm_sha:
    count = _count_first_parent_commits(llvm_config, base_llvm_sha, sha)
    return Rev(branch=MAIN_BRANCH, number=count + base_llvm_revision)

  # Otherwise, either:
  # - |merge_base| is |sha| (we have a guaranteed llvm-svn number on |sha|)
  # - |merge_base| is neither (we have a guaranteed llvm-svn number on
  #                            |merge_base|, but not |sha|)
  merge_base_number = translate_prebase_sha_to_rev_number(
      llvm_config, merge_base)
  if merge_base == sha:
    return Rev(branch=MAIN_BRANCH, number=merge_base_number)

  distance_from_base = _count_first_parent_commits(llvm_config, merge_base,
                                                   sha)
  revision_number = merge_base_number + distance_from_base

  branches_containing = check_output(
      ['git', 'branch', '-r', '--contains', sha],
      cwd=llvm_config.dir,
  )

  # Keep only branches of the configured remote, with the remote name dropped.
  prefix = llvm_config.remote + '/'
  stripped_branches = (x.strip() for x in branches_containing.splitlines())
  candidates = [
      branch[len(prefix):]
      for branch in stripped_branches
      if branch.startswith(prefix)
  ]

  if not candidates:
    raise ValueError(
        f'No viable branches found from {llvm_config.remote} with {sha}')

  if len(candidates) != 1:
    raise ValueError(
        f'Ambiguity: multiple branches from {llvm_config.remote} have {sha}: '
        f'{sorted(candidates)}')

  return Rev(branch=candidates[0], number=revision_number)
192
193
def parse_git_commit_messages(stream: t.Iterable[str],
                              separator: str) -> t.Iterable[t.Tuple[str, str]]:
  """Splits a stream of git log lines into (sha, message) pairs.

  The stream is expected to repeat the following layout, one record per
  commit:

  40 character sha
  commit
  message
  body
  separator

  Each yielded message is the concatenation of the lines between the sha line
  and the separator line. A record that is cut off by the end of the stream is
  still yielded with whatever lines were seen.
  """
  remaining = iter(stream)
  for header in remaining:
    sha = header.strip()
    assert is_git_sha(sha), f'Invalid git SHA: {sha}'

    # Consume lines up to, and including, this record's separator.
    body = []
    line = next(remaining, None)
    while line is not None and line.strip() != separator:
      body.append(line)
      line = next(remaining, None)

    yield sha, ''.join(body)
230
231
def translate_prebase_rev_to_sha(llvm_config: LLVMConfig, rev: Rev) -> str:
  """Translates a Rev to a SHA.

  This function assumes that the given rev refers to a commit that's an
  ancestor of |base_llvm_sha|.

  Raises a ValueError if no commit's message ends in the matching llvm-svn
  line, and a subprocess.CalledProcessError if git exits with an error.
  """
  # Because reverts may include reverted commit messages, we can't just |-n1|
  # and pick that.
  separator = '>!' * 80
  looking_for = f'llvm-svn: {rev.number}'

  git_command = [
      'git', 'log', '--grep', f'^{looking_for}$', f'--format=%H%n%B{separator}',
      base_llvm_sha
  ]

  subp = subprocess.Popen(
      git_command,
      cwd=llvm_config.dir,
      stdin=subprocess.DEVNULL,
      stdout=subprocess.PIPE,
      encoding='utf-8',
  )

  with subp:
    for sha, message in parse_git_commit_messages(subp.stdout, separator):
      # Strip before picking the last line, as
      # translate_prebase_sha_to_rev_number does, so that trailing blank lines
      # don't hide the llvm-svn line and an empty message can't raise an
      # IndexError.
      message_lines = message.strip().splitlines()
      if message_lines and message_lines[-1].strip() == looking_for:
        # We have our answer; there's no need to let git keep walking history.
        subp.terminate()
        return sha

  if subp.returncode:
    raise subprocess.CalledProcessError(subp.returncode, git_command)
  raise ValueError(f'No commit with revision {rev} found')
266
267
def translate_rev_to_sha(llvm_config: LLVMConfig, rev: Rev) -> str:
  """Translates a Rev to a SHA.

  Revisions on MAIN_BRANCH below |base_llvm_revision| are looked up through
  their llvm-svn commit message line. Everything else is found by counting
  first-parent commits back from the head of the remote branch.

  Raises a ValueError if the given Rev doesn't exist in the given config.
  """
  branch, number = rev

  if branch == MAIN_BRANCH:
    if number < base_llvm_revision:
      return translate_prebase_rev_to_sha(llvm_config, rev)
    base_sha = base_llvm_sha
    base_revision_number = base_llvm_revision
  else:
    base_sha = check_output(
        ['git', 'merge-base', base_llvm_sha, f'{llvm_config.remote}/{branch}'],
        cwd=llvm_config.dir,
    )
    base_sha = base_sha.strip()
    if base_sha == base_llvm_sha:
      base_revision_number = base_llvm_revision
    else:
      base_revision_number = translate_prebase_sha_to_rev_number(
          llvm_config, base_sha)

  commit_number = number - base_revision_number
  # Counting only works from |base_sha| onwards. A smaller number would make us
  # walk past |base_sha| and hand back an unrelated commit, so refuse instead.
  if commit_number < 0:
    raise ValueError(
        f'Revision {rev} is older than r{base_revision_number}, the oldest '
        f'revision that can be resolved on {llvm_config.remote}/{branch}')

  # Alternatively, we could |git log --format=%H|, but git is *super* fast
  # about rev walking/counting locally compared to long |log|s, so we walk back
  # twice.
  head = check_output(
      ['git', 'rev-parse', f'{llvm_config.remote}/{branch}'],
      cwd=llvm_config.dir,
  )
  branch_head_sha = head.strip()

  revs_between_str = check_output(
      [
          'git',
          'rev-list',
          '--count',
          '--first-parent',
          f'{base_sha}..{branch_head_sha}',
      ],
      cwd=llvm_config.dir,
  )
  revs_between = int(revs_between_str.strip())

  commits_behind_head = revs_between - commit_number
  if commits_behind_head < 0:
    raise ValueError(
        f'Revision {rev} is past {llvm_config.remote}/{branch}. Try updating '
        'your tree?')

  result = check_output(
      ['git', 'rev-parse', f'{branch_head_sha}~{commits_behind_head}'],
      cwd=llvm_config.dir,
  )

  return result.strip()
326
327
def find_root_llvm_dir(root_dir: str = '.') -> str:
  """Returns the top-level directory of the git checkout holding |root_dir|.

  Raises a subprocess.CalledProcessError if no git directory is found.
  """
  toplevel_command = ['git', 'rev-parse', '--show-toplevel']
  return check_output(toplevel_command, cwd=root_dir).strip()
338
339
def main(argv: t.List[str]) -> None:
  """Converts between SHAs and revs as requested on the command line.

  Exactly one of --sha and --rev must be given; the translated value is
  printed to stdout.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument(
      '--llvm_dir',
      help='LLVM directory to consult for git history, etc. Autodetected '
      'if cwd is inside of an LLVM tree')
  parser.add_argument(
      '--upstream',
      default='origin',
      help="LLVM upstream's remote name. Defaults to %(default)s.")
  sha_or_rev = parser.add_mutually_exclusive_group(required=True)
  sha_or_rev.add_argument(
      '--sha', help='A git SHA (or ref) to convert to a rev')
  sha_or_rev.add_argument('--rev', help='A rev to convert into a sha')
  opts = parser.parse_args(argv)

  # Autodetect once, under the error handler, and reuse the result. An empty
  # --llvm_dir is treated the same as an absent one.
  llvm_dir = opts.llvm_dir
  if not llvm_dir:
    try:
      llvm_dir = find_root_llvm_dir()
    except subprocess.CalledProcessError:
      parser.error("Couldn't autodetect an LLVM tree; please use --llvm_dir")

  config = LLVMConfig(
      remote=opts.upstream,
      dir=llvm_dir,
  )

  if opts.sha:
    rev = translate_sha_to_rev(config, opts.sha)
    print(rev)
  else:
    sha = translate_rev_to_sha(config, Rev.parse(opts.rev))
    print(sha)


if __name__ == '__main__':
  main(sys.argv[1:])
378