1# Copyright 2021 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Common Utils."""
16
17# pylint: disable=g-importing-member
18from dataclasses import dataclass
19import io
20from pathlib import Path
21from pathlib import PurePath
22import sys
23from typing import List
24from typing import Set
25
26# pylint: disable=g-import-not-at-top
try:
  from git import Blob
  from git import Commit
  from git import Tree
except ModuleNotFoundError:
  # GitPython is a hard requirement; tell the user how to get it and stop.
  print(
      'ERROR: Please install GitPython by `pip3 install GitPython`.',
      file=sys.stderr)
  # sys.exit() is used instead of the exit() helper because the helper is
  # injected by the site module for interactive sessions and is not available
  # when the interpreter runs without site (e.g. `python -S`).
  sys.exit(1)
36
# Directory that contains this file.
THIS_DIR = Path(__file__).resolve().parent
# Two directory levels above THIS_DIR. The default EXPECTED_UPSTREAM file is
# looked up directly under this directory.
LIBCORE_DIR = THIS_DIR.parent.parent.resolve()

# Directory prefixes, inside an upstream commit, that are tried in this order
# when locating the source file of a class.
UPSTREAM_CLASS_PATHS = [
    'jdk/src/share/classes/',
    'src/java.base/share/classes/',
    'src/java.base/linux/classes/',
    'src/java.base/unix/classes/',
    'src/java.sql/share/classes/',
    'src/java.logging/share/classes/',
    'src/java.prefs/share/classes/',
    'src/java.net/share/classes/',
]

# Directory prefixes, inside an upstream commit, that are tried in this order
# when locating a test source file.
UPSTREAM_TEST_PATHS = [
    'jdk/test/',
    'test/jdk/',
]

# All upstream prefixes that are searched; class paths come before test paths.
UPSTREAM_SEARCH_PATHS = UPSTREAM_CLASS_PATHS + UPSTREAM_TEST_PATHS

# Prefix of the ojluni paths that correspond to the upstream class paths.
OJLUNI_JAVA_BASE_PATH = 'ojluni/src/main/java/'
# Prefix used to resolve class names that start with 'test.'.
OJLUNI_TEST_PATH = 'ojluni/src/'
# Prefix of the ojluni paths that correspond to the upstream test paths.
TEST_PATH = OJLUNI_TEST_PATH + 'test/'
61
62
@dataclass
class ExpectedUpstreamEntry:
  """A map entry in the EXPECTED_UPSTREAM file.

  Two entries are equal only when all four fields, including the comment
  lines, are equal. Defining __eq__ leaves instances unhashable.
  """
  dst_path: str  # destination path
  git_ref: str  # a git reference to an upstream commit
  src_path: str  # source path in the commit pointed by the git_ref
  comment_lines: str = ''  # The comment lines above the entry line

  def __eq__(self, other):
    """Compares this entry with another entry field by field.

    Args:
      other: the object to compare with

    Returns:
      True or False when other is an ExpectedUpstreamEntry; NotImplemented
      otherwise, so that Python can try the reflected comparison and finally
      fall back to an identity check (which yields False for `==`).
    """
    if not isinstance(other, ExpectedUpstreamEntry):
      # Do not claim to know how to compare with foreign types.
      return NotImplemented

    return (self.dst_path == other.dst_path and
            self.git_ref == other.git_ref and
            self.src_path == other.src_path and
            self.comment_lines == other.comment_lines)
79
80
class ExpectedUpstreamFile:
  """A file object representing the EXPECTED_UPSTREAM file.

  Every entry line has the form `<dst_path>,<git_ref>,<src_path>`. Blank lines
  and lines starting with '#' are stored as the comment_lines of the entry
  that follows them.
  """

  def __init__(self, file_or_bytes=LIBCORE_DIR / 'EXPECTED_UPSTREAM'):
    """Creates a file object backed by a path or by in-memory bytes.

    Args:
      file_or_bytes: a Path of the file, or the UTF-8 encoded content of the
        file as bytes. Defaults to the EXPECTED_UPSTREAM file in LIBCORE_DIR.

    Raises:
      NotImplementedError: if file_or_bytes is neither a Path nor bytes.
    """
    if isinstance(file_or_bytes, Path):
      path = Path(file_or_bytes)
      # pylint: disable=unnecessary-lambda
      self.openable = lambda mode: path.open(mode)
    elif isinstance(file_or_bytes, bytes):
      # The mode is ignored here: every call returns a fresh in-memory text
      # stream of the given content, so anything written to it is discarded.
      self.openable = lambda mode: io.StringIO(file_or_bytes.decode('utf-8'))
    else:
      raise NotImplementedError('Only support bytes or Path type')

  def read_all_entries(self) -> List[ExpectedUpstreamEntry]:
    """Read all entries from the file.

    Returns:
      The entries in file order. Comment or blank lines that come after the
      last entry are not attached to any entry and are not returned.

    Raises:
      ValueError: if an entry line doesn't have exactly 3 comma-separated
        items.
    """
    result: List[ExpectedUpstreamEntry] = []
    with self.openable('r') as file:
      comment_lines = ''  # Store the comment lines in the next entry
      for line in file:
        stripped = line.strip()
        # Empty lines and comments starting with '#' belong to the next entry
        if not stripped or stripped.startswith('#'):
          comment_lines += line
          continue

        result.append(self.parse_line(stripped, comment_lines))
        comment_lines = ''

    return result

  def write_all_entries(self, entries: List[ExpectedUpstreamEntry]) -> None:
    """Write all entries into the file, replacing the previous content.

    Args:
      entries: the entries to write, in the order they should appear
    """
    with self.openable('w') as file:
      for e in entries:
        file.write(e.comment_lines)
        file.write(','.join([e.dst_path, e.git_ref, e.src_path]))
        file.write('\n')

  def write_new_entry(self, entry: ExpectedUpstreamEntry,
                      entries: List[ExpectedUpstreamEntry] = None) -> None:
    """Adds an entry, then sorts and writes all entries into the file.

    Args:
      entry: the entry to add
      entries: the existing entries. They are read from the file when this is
        None. The given list is modified in place.
    """
    if entries is None:
      entries = self.read_all_entries()

    entries.append(entry)
    self.sort_and_write_all_entries(entries)

  def sort_and_write_all_entries(self,
                                 entries: List[ExpectedUpstreamEntry]) -> None:
    """Sorts the entries by dst_path and writes them into the file.

    The comment lines of the entry that is first before sorting are treated as
    the file header and are moved above whichever entry is first after sorting.

    Args:
      entries: the entries to write. The list is sorted in place and the
        comment_lines of its entries may be modified.
    """
    if not entries:
      # There is no header to carry over; behave like write_all_entries([])
      # instead of failing with IndexError on entries[0].
      self.write_all_entries(entries)
      return

    header = entries[0].comment_lines
    entries[0].comment_lines = ''
    entries.sort(key=lambda e: e.dst_path)
    # Keep the header above the first entry
    entries[0].comment_lines = header + entries[0].comment_lines
    self.write_all_entries(entries)

  def get_new_or_modified_entries(
      self, other: 'ExpectedUpstreamFile') -> List[ExpectedUpstreamEntry]:
    """Return a list of modified and added entries from the other file.

    Args:
      other: the other file

    Returns:
      The entries of the other file whose dst_path is absent in this file, or
      that differ from the entry with the same dst_path in this file.
    """
    # When a dst_path occurs more than once in this file, the last one wins.
    this_map = {e.dst_path: e for e in self.read_all_entries()}

    result: List[ExpectedUpstreamEntry] = []
    for e in other.read_all_entries():
      value = this_map.get(e.dst_path)
      if value is None or value != e:
        result.append(e)

    return result

  def get_removed_paths(self, other: 'ExpectedUpstreamFile') -> Set[str]:
    """Returns a set of dst paths removed in the new list.

    Args:
      other: the other file

    Returns:
      The dst paths that exist in this file but not in the other file.
    """
    this_paths = [e.dst_path for e in self.read_all_entries()]
    that_paths = {e.dst_path for e in other.read_all_entries()}
    return {p for p in this_paths if p not in that_paths}

  @staticmethod
  def parse_line(line: str, comment_lines: str) -> ExpectedUpstreamEntry:
    """Parses an entry line into an entry.

    Args:
      line: the entry line in the form `<dst_path>,<git_ref>,<src_path>`
      comment_lines: the comment lines above the entry line

    Returns:
      The parsed entry.

    Raises:
      ValueError: if the line doesn't have exactly 3 comma-separated items.
    """
    items = line.split(',')
    size = len(items)
    if size != 3:
      raise ValueError(
          f"The size must be 3, but is {size}. The line is '{line}'")

    return ExpectedUpstreamEntry(items[0], items[1], items[2], comment_lines)
186
187
class OjluniFinder:
  """Looks up java class names and ojluni/ paths among known file paths."""

  def __init__(self, existing_paths: List[str]):
    # The file paths that the prefix matchers search through.
    self.existing_paths = existing_paths

  @staticmethod
  def translate_ojluni_path_to_class_name(path: str) -> str:
    """Converts an Ojluni .java file path into a full class name.

    Args:
      path: ojluni path

    Returns:
      The class name, or None if the path isn't a .java file under a known
      ojluni directory.
    """
    if not path.endswith('.java'):
      return None

    # The more specific java base path has to be tried first.
    for prefix in (OJLUNI_JAVA_BASE_PATH, OJLUNI_TEST_PATH):
      if path.startswith(prefix):
        return path[len(prefix):-len('.java')].replace('/', '.')

    return None

  @staticmethod
  def translate_from_class_name_to_ojluni_path(class_or_path: str) -> str:
    """Converts a class name into an ojluni path; paths are returned as is."""
    # Anything containing '/' is treated as a path already.
    if '/' in class_or_path:
      return class_or_path

    if class_or_path.startswith('test.'):
      prefix = OJLUNI_TEST_PATH
    else:
      prefix = OJLUNI_JAVA_BASE_PATH

    return prefix + class_or_path.replace('.', '/') + '.java'

  def match_path_prefix(self, input_path: str) -> Set[str]:
    """Returns the completions of a partial path among the existing paths.

    Args:
      input_path: a partial or complete path

    Returns:
      A set with only input_path if it is an existing path. Otherwise, a set
      of the next path components of the matching paths, where directories
      carry a trailing '/'.
    """
    candidates = [p for p in self.existing_paths if p.startswith(input_path)]
    # An exact hit needs no completion.
    if input_path in candidates:
      return {input_path}

    # With a trailing '/', the children of the input directory are listed.
    # Otherwise, the last component is a partial name inside its parent.
    anchor = PurePath(input_path)
    if not input_path.endswith('/'):
      anchor = anchor.parent
    depth = len(anchor.parts)

    completions: Set[str] = set()
    for candidate in candidates:
      # Every candidate is located below the anchor directory, so the
      # component at index `depth` is the child to complete to.
      child = PurePath(candidate).parts[depth]
      completion = (anchor / child).as_posix()
      # A completion shorter than the candidate stands for a directory.
      if completion != candidate:
        completion += '/'
      completions.add(completion)

    return completions

  def match_classname_prefix(self, input_class_name: str) -> List[str]:
    """Returns the package / class names completing a partial class name."""
    # A '/' indicates a path rather than a partial class name.
    if '/' in input_class_name:
      return []

    relative = input_class_name.replace('.', '/')
    names = []
    for base_path in (OJLUNI_JAVA_BASE_PATH, OJLUNI_TEST_PATH):
      for found in self.match_path_prefix(base_path + relative):
        names.append(convert_path_to_java_class_name(found, base_path))

    return names
278
279
class OpenjdkFinder:
  """Looks up java classes or paths in the tree of an upstream commit."""

  def __init__(self, commit: Commit):
    # The upstream commit whose tree is searched.
    self.commit = commit

  @staticmethod
  def translate_src_path_to_ojluni_path(src_path: str) -> str:
    """Converts an upstream source path into the matching ojluni path.

    Args:
      src_path: a path in the upstream repository

    Returns:
      The ojluni path, or None if src_path isn't in a known source directory.
    """

    def strip_first_matching(bases):
      # Remainder of src_path after the first matching base, None if no match.
      return next(
          (src_path[len(base):] for base in bases if src_path.startswith(base)),
          None)

    # Test directories take precedence over class directories.
    relative_path = strip_first_matching(UPSTREAM_TEST_PATHS)
    if relative_path:
      return f'{OJLUNI_TEST_PATH}test/{relative_path}'

    relative_path = strip_first_matching(UPSTREAM_CLASS_PATHS)
    if relative_path:
      return f'{OJLUNI_JAVA_BASE_PATH}{relative_path}'

    return None

  def find_src_path_from_classname(self, class_or_path: str) -> str:
    """Resolves a class name or path to an existing source path.

    Args:
      class_or_path: a full class name, or a path if it contains '/'

    Returns:
      The source path in the commit, or None if no such file exists.
    """
    # Anything containing '/' is treated as a path.
    if '/' in class_or_path:
      return class_or_path if self.has_file(class_or_path) else None

    relative = class_or_path.replace('.', '/')
    # The first search path that holds the file wins.
    for base_path in UPSTREAM_SEARCH_PATHS:
      candidate = f'{base_path}{relative}.java'
      if self.has_file(candidate):
        return candidate

    return None

  def get_search_paths(self) -> List[str]:
    """Returns the upstream directories that are searched for sources."""
    return UPSTREAM_SEARCH_PATHS

  def find_src_path_from_ojluni_path(self, ojluni_path: str) -> str:
    """Guesses the upstream source path of an ojluni path.

    Args:
      ojluni_path: a path under the ojluni java base or test directory

    Returns:
      The first candidate path existing in the commit, or None if the
      ojluni path is in an unknown directory or no candidate exists.
    """
    if ojluni_path.startswith(OJLUNI_JAVA_BASE_PATH):
      prefix, base_paths = OJLUNI_JAVA_BASE_PATH, UPSTREAM_CLASS_PATHS
    elif ojluni_path.startswith(TEST_PATH):
      prefix, base_paths = TEST_PATH, UPSTREAM_TEST_PATHS
    else:
      return None

    relative = ojluni_path[len(prefix):]
    for base_path in base_paths:
      candidate = base_path + relative
      if self.has_file(candidate):
        return candidate

    return None

  def match_path_prefix(self, input_path: str) -> List[str]:
    """Returns the source paths completing the given partial path.

    Args:
      input_path: a partial or complete path in the commit

    Returns:
      A list with only input_path if it is an existing file. Otherwise, the
      matching directories, each with a trailing '/', followed by the
      matching files. The list is empty if the directory to search in
      doesn't exist.
    """
    tree_to_search = self.commit.tree
    parsed = PurePath(input_path)
    posix_path = parsed.as_posix()
    exists = self.has_file(posix_path)
    found = tree_to_search[posix_path] if exists else None

    # An existing file needs no completion.
    if exists and isinstance(found, Blob):
      return [input_path]

    name_prefix = ''
    if input_path.endswith('/'):
      # A trailing '/' names a directory itself, never a partial name, so
      # either list its children or give up if it doesn't exist.
      if not exists:
        return []
      tree_to_search = found
    elif len(parsed.parts) >= 2:
      # Complete the last component inside its parent directory.
      parent = parsed.parent.as_posix()
      if not self.has_file(parent):
        return []
      tree_to_search = tree_to_search[parent]
      name_prefix = parsed.name
    else:
      # A single component is completed in the root of the commit.
      name_prefix = input_path

    # Directories get a trailing '/'. When a directory is the only result,
    # the shell auto-fills it, and the next completion request ends with '/'
    # and therefore lists the children of that directory.
    matches = [
        subtree.path + '/'
        for subtree in tree_to_search.trees
        if PurePath(subtree.path).name.startswith(name_prefix)
    ]
    matches.extend(
        blob.path
        for blob in tree_to_search.blobs
        if PurePath(blob.path).name.startswith(name_prefix))
    return matches

  def match_classname_prefix(self, input_class_name: str) -> List[str]:
    """Returns the package / class names completing a partial class name."""
    # A '/' indicates a path rather than a class name.
    if '/' in input_class_name:
      return []

    relative = input_class_name.replace('.', '/')
    names = []
    for base_path in UPSTREAM_SEARCH_PATHS:
      for found in self.match_path_prefix(base_path + relative):
        tail = found[len(base_path):]
        if tail.endswith('.java'):
          tail = tail[:-len('.java')]
        names.append(tail.replace('/', '.'))

    return names

  def has_file(self, path: str) -> bool:
    """Returns True if the directory / file exists in the commit's tree."""
    return has_file_in_tree(path, self.commit.tree)
427
428
def convert_path_to_java_class_name(path: str, base_path: str) -> str:
  """Converts a path below base_path into a dotted package / class name.

  Args:
    path: a file or directory path that begins with base_path
    base_path: the leading part of the path to drop

  Returns:
    The remaining path without a trailing '.java' and with every '/'
    replaced by '.'.
  """
  suffix = '.java'
  relative = path[len(base_path):]
  stem = relative[:-len(suffix)] if relative.endswith(suffix) else relative
  return stem.replace('/', '.')
436
437
def has_file_in_tree(path: str, tree: Tree) -> bool:
  """Returns True if the directory / file exists in the tree.

  Args:
    path: the path to look up
    tree: the tree to look the path up in

  Returns:
    False if the lookup raises KeyError, True otherwise.
  """
  try:
    # The lookup is done only to find out whether it raises.
    # pylint: disable=pointless-statement
    tree[path]
  except KeyError:
    return False
  return True
446