1#!/usr/bin/env python3
2
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19"""A command line utility to pull multiple change lists from Gerrit."""
20
21from __future__ import print_function
22
23import argparse
24import collections
25import itertools
26import json
27import multiprocessing
28import os
29import os.path
30import re
31import sys
32import xml.dom.minidom
33
34from gerrit import (
35    add_common_parse_args, create_url_opener_from_args, find_gerrit_name,
36    normalize_gerrit_name, query_change_lists, run
37)
38from subprocess import PIPE
39
40try:
41    # pylint: disable=redefined-builtin
42    from __builtin__ import raw_input as input  # PY2
43except ImportError:
44    pass
45
46try:
47    from shlex import quote as _sh_quote  # PY3.3
48except ImportError:
49    # Shell language simple string pattern.  If a string matches this pattern,
50    # it doesn't have to be quoted.
51    _SHELL_SIMPLE_PATTERN = re.compile('^[a-zA-Z90-9_./-]+$')
52
53    def _sh_quote(txt):
54        """Quote a string if it contains special characters."""
55        return txt if _SHELL_SIMPLE_PATTERN.match(txt) else json.dumps(txt)
56
57
if bytes is str:
    def write_bytes(data, file):  # PY2
        """Write raw bytes to a file object.

        On Python 2 `str` is the bytes type, so the data is written directly.
        """
        # pylint: disable=redefined-builtin
        file.write(data)
else:
    def write_bytes(data, file):  # PY3
        """Write raw bytes to the binary buffer underneath a text stream.

        The text layer is flushed first so that text already written to
        `file` (e.g. by print()) reaches the buffer before `data`; otherwise
        the bytes could be emitted ahead of earlier, still-pending text.
        """
        # pylint: disable=redefined-builtin
        file.flush()
        file.buffer.write(data)
68
69
70def _confirm(question, default, file=sys.stderr):
71    """Prompt a yes/no question and convert the answer to a boolean value."""
72    # pylint: disable=redefined-builtin
73    answers = {'': default, 'y': True, 'yes': True, 'n': False, 'no': False}
74    suffix = '[Y/n] ' if default else ' [y/N] '
75    while True:
76        file.write(question + suffix)
77        file.flush()
78        ans = answers.get(input().lower())
79        if ans is not None:
80            return ans
81
82
class ChangeList(object):
    """A single Gerrit change, at one revision, that is to be pulled."""
    # pylint: disable=too-few-public-methods,too-many-instance-attributes

    def __init__(self, project, fetch, commit_sha1, commit, change_list):
        """Record the change metadata and choose where to fetch it from.

        Args:
            project: Gerrit project name.
            fetch: Mapping from protocol name to a dict with 'url' and 'ref'.
            commit_sha1: SHA-1 of the revision to pull.
            commit: Commit dict; its 'parents' list is kept.
            change_list: Change dict; '_number' and 'branch' are read.

        Raises:
            ValueError: if none of http, sso, rpc, ssh (tried in that order)
                has a truthy entry in `fetch`.
        """
        # pylint: disable=too-many-arguments
        self.project = project
        self.number = change_list['_number']
        self.branch = change_list['branch']
        self.fetch = fetch

        # First protocol, in preference order, that has a usable entry.
        candidates = (fetch.get(proto) for proto in ('http', 'sso', 'rpc', 'ssh'))
        fetch_git = next((entry for entry in candidates if entry), None)
        if not fetch_git:
            raise ValueError(
                'unknown fetch protocols: ' + str(list(fetch.keys())))

        self.fetch_url = fetch_git['url']
        self.fetch_ref = fetch_git['ref']
        self.commit_sha1 = commit_sha1
        self.commit = commit
        self.parents = commit['parents']
        self.change_list = change_list

    def is_merge(self):
        """Return True when the commit has more than one parent."""
        return len(self.parents) >= 2
120
121
def find_repo_top(curdir):
    """Walk up from curdir and return the first directory containing `.repo`.

    Raises:
        ValueError: if no ancestor (including curdir itself) has `.repo`.
    """
    while True:
        if os.path.exists(os.path.join(curdir, '.repo')):
            return curdir
        parent = os.path.dirname(curdir)
        if parent == curdir:
            # dirname() has reached its fixed point (the root); stop.
            break
        curdir = parent
    raise ValueError('.repo dir not found')
131
132
class ProjectNameDirDict(object):
    """Lookup table from a project name (optionally with revision) to a path."""

    def __init__(self):
        self._dirs = {}

    def add_directory(self, name, revision, path):
        """Record `path` (or `name` when path is empty) for the project.

        The entry is always stored under `name`, and additionally under
        `(name, revision)` when a revision is given.
        """
        target = path if path else name
        self._dirs[name] = target
        if revision:
            self._dirs[(name, revision)] = target

    def find_directory(self, name, revision, default_result=None):
        """Return the path for `(name, revision)`, falling back to `name`.

        If neither key is known, `default_result` is returned when it is not
        None; otherwise KeyError propagates.
        """
        exact_key = (name, revision)
        if exact_key in self._dirs:
            return self._dirs[exact_key]
        if default_result is not None:
            return self._dirs.get(name, default_result)
        return self._dirs[name]
153
154
def build_project_name_dir_dict(manifest_name):
    """Build the Gerrit-project-name to source-directory mapping.

    Runs `repo manifest` (adding `-m manifest_name` when one is given), parses
    the XML it prints, and registers every <project> element's name, revision
    and path attributes.

    Returns:
        A ProjectNameDirDict.
    """
    manifest_cmd = ['repo', 'manifest']
    if manifest_name:
        manifest_cmd += ['-m', manifest_name]
    proc = run(manifest_cmd, stdout=PIPE, check=True)
    dom = xml.dom.minidom.parseString(proc.stdout)

    project_dirs = ProjectNameDirDict()
    for node in dom.getElementsByTagName('project'):
        attr = node.getAttribute
        project_dirs.add_directory(attr('name'), attr('revision'), attr('path'))
    return project_dirs
172
173
def group_and_sort_change_lists(change_lists):
    """Group change lists by project and topologically sort each group.

    Args:
        change_lists: Gerrit change dicts. Each needs a 'project' name and a
            'revisions' dict mapping commit SHA-1 to a revision dict with
            'fetch' and 'commit' entries. If 'current_revision' is present
            and names one of those revisions, that revision is used.

    Returns:
        A list with one inner list per project, ordered by project name. Each
        inner list holds ChangeList objects ordered so that a change comes
        after every parent commit that is also in the same group; apart from
        that constraint, changes are visited in ascending change number.

    Raises:
        ValueError: if a change list has no usable revision.
        KeyError: if the same commit appears twice within one project.
    """
    # Map each project name to a {commit sha1: ChangeList} dict.
    projects = collections.defaultdict(dict)
    for change_list in change_lists:
        revisions = change_list['revisions']
        commit_sha1 = change_list.get('current_revision')
        if commit_sha1 not in revisions:
            # No usable current_revision: take the last listed revision (a
            # query for the current revision only returns one anyway).
            commit_sha1 = list(revisions)[-1] if revisions else None

        if not commit_sha1:
            raise ValueError('bad revision')

        revision = revisions[commit_sha1]
        project = change_list['project']

        project_changes = projects[project]
        if commit_sha1 in project_changes:
            raise KeyError('repeated commit sha1 "{}" in project "{}"'.format(
                commit_sha1, project))

        project_changes[commit_sha1] = ChangeList(
            project, revision['fetch'], commit_sha1, revision['commit'],
            change_list)

    def _sort_project_change_lists(changes):
        """Post-order DFS over in-group parent links (parents first)."""
        visited = set()
        ordered = []
        for root in sorted(changes.values(), key=lambda x: x.number):
            if root in visited:
                continue
            visited.add(root)
            # Each stack entry pairs a change with an iterator over its
            # not-yet-examined parents.  An explicit stack (instead of
            # recursion) keeps long chains of dependent changes from hitting
            # Python's recursion limit.
            stack = [(root, iter(root.parents))]
            while stack:
                change, parents = stack[-1]
                for parent in parents:
                    parent_change = changes.get(parent['commit'])
                    if (parent_change is not None and
                            parent_change not in visited):
                        visited.add(parent_change)
                        stack.append(
                            (parent_change, iter(parent_change.parents)))
                        break
                else:
                    # Every parent has been emitted; now emit this change.
                    stack.pop()
                    ordered.append(change)
        return ordered

    return [_sort_project_change_lists(projects[project])
            for project in sorted(projects)]
224
225
def _main_json(args):
    """Write the queried change lists to stdout as indented JSON."""
    json.dump(_get_change_lists_from_args(args), sys.stdout,
              indent=4, separators=(', ', ': '))
    # json.dump() leaves the output without a final newline.
    sys.stdout.write('\n')
231
232
233# Git commands for merge commits
234_MERGE_COMMANDS = {
235    'merge': ['git', 'merge', '--no-edit'],
236    'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'],
237    'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'],
238    'reset': ['git', 'reset', '--hard'],
239    'checkout': ['git', 'checkout'],
240}
241
242
243# Git commands for non-merge commits
244_PICK_COMMANDS = {
245    'pick': ['git', 'cherry-pick', '--allow-empty'],
246    'merge': ['git', 'merge', '--no-edit'],
247    'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'],
248    'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'],
249    'reset': ['git', 'reset', '--hard'],
250    'checkout': ['git', 'checkout'],
251}
252
253
def build_pull_commands(change, branch_name, merge_opt, pick_opt):
    """Return the argv lists that pull `change` into the current checkout.

    The sequence is: `repo start <branch_name>` (only when branch_name is not
    None), a `git fetch` of the change's URL and ref, and finally the
    merge_opt command (merge commits) or pick_opt command (other commits)
    applied to FETCH_HEAD.  Each entry is an argument list for
    subprocess.run().
    """
    if change.is_merge():
        integrate = _MERGE_COMMANDS[merge_opt]
    else:
        integrate = _PICK_COMMANDS[pick_opt]

    prefix = [] if branch_name is None else [['repo', 'start', branch_name]]
    return prefix + [
        ['git', 'fetch', change.fetch_url, change.fetch_ref],
        integrate + ['FETCH_HEAD'],
    ]
267
268
269def _sh_quote_command(cmd):
270    """Convert a command (an argument to subprocess.run()) to a shell command
271    string."""
272    return ' '.join(_sh_quote(x) for x in cmd)
273
274
def _sh_quote_commands(cmds):
    """Render several argument lists as one shell line in which each command
    runs only if the previous one succeeded (joined with ' && ')."""
    rendered = []
    for cmd in cmds:
        rendered.append(_sh_quote_command(cmd))
    return ' && '.join(rendered)
279
280
def _main_bash(args):
    """Print a bash script that pulls the queried change lists.

    The script runs from the repo top; each change becomes one line that
    enters its project directory, runs the pull commands and leaves again.
    """
    repo_top = find_repo_top(os.getcwd())
    project_dirs = build_project_name_dir_dict(args.manifest)
    branch_name = _get_local_branch_name_from_args(args)

    groups = group_and_sort_change_lists(_get_change_lists_from_args(args))

    print(_sh_quote_command(['pushd', repo_top]))
    for change in itertools.chain.from_iterable(groups):
        # Projects missing from the manifest fall back to their name as path.
        path = project_dirs.find_directory(
            change.project, change.branch, change.project)
        steps = [['pushd', path]]
        steps += build_pull_commands(change, branch_name, args.merge, args.pick)
        steps.append(['popd'])
        print(_sh_quote_commands(steps))
    print(_sh_quote_command(['popd']))
302
303
304def _do_pull_change_lists_for_project(task, ignore_unknown_changes):
305    """Pick a list of changes (usually under a project directory)."""
306    changes, task_opts = task
307
308    branch_name = task_opts['branch_name']
309    merge_opt = task_opts['merge_opt']
310    pick_opt = task_opts['pick_opt']
311    project_dirs = task_opts['project_dirs']
312    repo_top = task_opts['repo_top']
313
314    for i, change in enumerate(changes):
315        try:
316            cwd = project_dirs.find_directory(change.project, change.branch)
317        except KeyError:
318            err_msg = 'error: project "{}" cannot be found in manifest.xml\n'
319            err_msg = err_msg.format(change.project).encode('utf-8')
320            if ignore_unknown_changes:
321                print(err_msg)
322                continue
323            return (change, changes[i + 1:], [], err_msg)
324
325        print(change.commit_sha1[0:10], i + 1, cwd)
326        cmds = build_pull_commands(change, branch_name, merge_opt, pick_opt)
327        for cmd in cmds:
328            proc = run(cmd, cwd=os.path.join(repo_top, cwd), stderr=PIPE)
329            if proc.returncode != 0:
330                return (change, changes[i + 1:], cmd, proc.stderr)
331    return None
332
333
def _print_pull_failures(failures, file=sys.stderr):
    """Print a report of failed pulls.

    Args:
        failures: Iterable of (failed_change, skipped_changes, cmd, errors)
            tuples as returned by _do_pull_change_lists_for_project(); errors
            is raw bytes and is written with write_bytes().
        file: Stream the whole report is written to.  On Python 3 it must
            expose a binary `.buffer` (as sys.stderr does) for the raw bytes.
    """
    # pylint: disable=redefined-builtin

    separator = '=' * 78
    separator_sub = '-' * 78

    print(separator, file=file)
    for failed_change, skipped_changes, cmd, errors in failures:
        print('PROJECT:', failed_change.project, file=file)
        print('FAILED COMMIT:', failed_change.commit_sha1, file=file)
        for change in skipped_changes:
            print('PENDING COMMIT:', change.commit_sha1, file=file)
        # Every part of the report, including the raw stderr bytes, goes to
        # `file` so the sections cannot end up on different streams.
        print(separator_sub, file=file)
        print('FAILED COMMAND:', _sh_quote_command(cmd), file=file)
        if errors:
            write_bytes(errors, file=file)
            # Keep the closing separator on its own line.
            if not errors.endswith(b'\n'):
                print(file=file)
        print(separator, file=file)
351
352
def _main_pull(args):
    """Pull the queried change lists into the local source tree.

    Changes are grouped per project.  Groups are processed one after another,
    or by a multiprocessing pool of `args.parallel` workers when that is
    greater than 1.  If any group fails, a report is printed and the process
    exits with status 1.
    """
    import functools  # only needed here

    repo_top = find_repo_top(os.getcwd())
    project_dirs = build_project_name_dir_dict(args.manifest)
    branch_name = _get_local_branch_name_from_args(args)

    # Collect change lists
    change_lists = _get_change_lists_from_args(args)
    change_list_groups = group_and_sort_change_lists(change_lists)

    # Options shared by every task
    task_opts = {
        'branch_name': branch_name,
        'merge_opt': args.merge,
        'pick_opt': args.pick,
        'project_dirs': project_dirs,
        'repo_top': repo_top,
    }
    tasks = [(changes, task_opts) for changes in change_list_groups]

    # Pool.map() calls the worker with exactly one argument (its third
    # positional parameter is chunksize), so bind the flag up front.  A
    # partial of a module-level function can be pickled for the workers.
    worker = functools.partial(
        _do_pull_change_lists_for_project,
        ignore_unknown_changes=args.ignore_unknown_changes)

    # Run the commands to pull the change lists
    if args.parallel <= 1:
        results = [worker(task) for task in tasks]
    else:
        pool = multiprocessing.Pool(processes=args.parallel)
        try:
            results = pool.map(worker, tasks)
        except BaseException:
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    # Print failures and tracebacks
    failures = [result for result in results if result]
    if failures:
        _print_pull_failures(failures)
        sys.exit(1)
388
389
def _parse_args():
    """Parse the command line.

    Returns:
        The argparse.Namespace with the selected command, the options added
        by add_common_parse_args(), and the options defined below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('command', choices=['pull', 'bash', 'json'],
                        help='Commands')
    add_common_parse_args(parser)

    parser.add_argument('--manifest', help='Manifest')

    parser.add_argument('-m', '--merge',
                        choices=sorted(_MERGE_COMMANDS.keys()),
                        default='merge-ff-only',
                        help='Method to pull merge commits')

    # --pick is used for commits with a single parent; see
    # build_pull_commands().
    parser.add_argument('-p', '--pick',
                        choices=sorted(_PICK_COMMANDS.keys()),
                        default='pick',
                        help='Method to pull non-merge commits')

    parser.add_argument('-b', '--branch',
                        help='Local branch name for `repo start`')

    parser.add_argument('-j', '--parallel', default=1, type=int,
                        help='Number of parallel running commands')

    parser.add_argument('--current-branch', action='store_true',
                        help='Pull commits to the current branch')

    parser.add_argument('--ignore-unknown-changes', action='store_true',
                        help='Ignore changes whose repo is not in the manifest')

    return parser.parse_args()
424
425
def _get_change_lists_from_args(args):
    """Run the Gerrit query described by `args` and return the change lists.

    Uses args.gerrit, args.query, args.start and args.limits, plus whatever
    create_url_opener_from_args() reads for authentication.
    """
    opener = create_url_opener_from_args(args)
    query_params = (args.gerrit, args.query, args.start, args.limits)
    return query_change_lists(opener, *query_params)
431
432
433def _get_local_branch_name_from_args(args):
434    """Get the local branch name from args."""
435    if not args.branch and not args.current_branch and not _confirm(
436            'Do you want to continue without local branch name?', False):
437        print('error: `-b` or `--branch` must be specified', file=sys.stderr)
438        sys.exit(1)
439    return args.branch
440
441
def main():
    """Entry point: parse arguments, resolve the Gerrit instance and run the
    selected command."""
    args = _parse_args()

    if args.gerrit:
        args.gerrit = normalize_gerrit_name(args.gerrit)
    else:
        try:
            args.gerrit = find_gerrit_name()
        # Any lookup failure becomes a usage error.  Unlike a bare except,
        # this still lets KeyboardInterrupt and SystemExit propagate.
        except Exception:  # pylint: disable=broad-except
            print('gerrit instance not found, use [-g GERRIT]',
                  file=sys.stderr)
            sys.exit(1)

    handlers = {
        'json': _main_json,
        'bash': _main_bash,
        'pull': _main_pull,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise KeyError('unknown command')
    handler(args)


if __name__ == '__main__':
    main()
467