1# Copyright 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import fnmatch
6import importlib
7import inspect
8import os
9import re
10import sys
11
12from py_utils import camel_case
13
14
def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Import and return every matching module found under |start_dir|.

  |start_dir| is walked recursively. Directories are visited in sorted path
  order and the files of each directory in sorted name order. A file is
  imported only if its name does not begin with '.' or '_', its extension is
  '.py', and its name matches |pattern|. The dotted module name is derived
  from the file's path relative to |top_level_dir|.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of the imported modules, in the order they were visited.
  """
  # Resolve both paths the same way so they stay consistent with each other.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  discovered = []
  # Ordering the walk results by directory path (and the file names below)
  # gives a deterministic traversal order.
  walk_entries = sorted(os.walk(start_dir), key=lambda entry: entry[0])
  for dir_path, _, filenames in walk_entries:
    for filename in sorted(filenames):
      # Skip hidden/private files, non-Python files and pattern mismatches.
      is_hidden = filename.startswith(('.', '_'))
      is_python = os.path.splitext(filename)[1] == '.py'
      if is_hidden or not is_python or not fnmatch.fnmatch(filename, pattern):
        continue

      # Turn the path relative to |top_level_dir| into a dotted module name.
      file_path = os.path.join(dir_path, filename)
      rel_path_no_ext = os.path.splitext(
          os.path.relpath(file_path, top_level_dir))[0]
      module_name = re.sub(r'[/\\]', '.', rel_path_no_ext)

      try:
        # Put top_level_dir at the front of sys.path for the duration of the
        # import so it wins if parts of the module name conflict with other
        # entries.
        saved_sys_path = sys.path[:]
        sys.path.insert(0, top_level_dir)
        discovered.append(importlib.import_module(module_name))
      finally:
        # Rebind sys.path to the snapshot taken before the import.
        sys.path = saved_sys_path
  return discovered
63
64
def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Raise if the two dicts map a shared key to different objects.

  Keys are checked in the iteration order of |classes_by_key_1|; the first
  conflicting key found is the one reported.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class} to compare against.

  Raises:
    AssertionError: if a key present in both dicts maps to two values that are
        not the same object (compared with |is|).
  """
  for key, class_1 in classes_by_key_1.items():
    if key not in classes_by_key_2:
      continue
    class_2 = classes_by_key_2[key]
    # Raised explicitly instead of through an |assert| statement so that the
    # check is still enforced when Python runs with -O, which strips asserts.
    if class_1 is not class_2:
      raise AssertionError(
          'Found conflicting classes for the same key: '
          'key=%s, class_1=%s, class_2=%s' % (key, class_1, class_2))
73
74# TODO(dtu): Normalize all discoverable classes to have corresponding module
75# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Collect the subclasses of |base_class| defined under |start_dir|.

  Every module returned by DiscoverModules() is scanned with
  DiscoverClassesInModule() and the per-module results are merged into one
  dict, in module discovery order.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
        In that mode the merge is guarded by AssertNoKeyConflicts(); when
        False, a later module's entry replaces an earlier one stored under
        the same key.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}
  """
  discovered = {}
  for module in DiscoverModules(start_dir, top_level_dir, pattern):
    module_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(crbug.com/548652): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
    # naming collisions to reduce the number of smoked benchmark tests.
    if index_by_class_name:
      AssertNoKeyConflicts(discovered, module_classes)
    discovered.update(module_classes)
  return discovered
111
112
# TODO(crbug.com/548652): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
# naming collisions to reduce the number of smoked benchmark tests.
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Collect the subclasses of |base_class| that are defined in |module|.

  A member of |module| is kept only if it is a class, is a subclass of
  |base_class| other than |base_class| itself, has a name that does not start
  with '_', and was defined in |module| (rather than imported into it).

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
        When False, the key is the last component of the module's dotted
        name, so a later matching class replaces an earlier one.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}
  """
  discovered = {}
  for _, member in inspect.getmembers(module):
    # The checks short-circuit in this order: must be a class, a strict
    # subclass of base_class, public, and defined in this very module (so
    # classes merely imported by the module are not reported twice).
    is_candidate = (
        inspect.isclass(member) and
        issubclass(member, base_class) and
        member is not base_class and
        not member.__name__.startswith('_') and
        member.__module__ == module.__name__)
    if not is_candidate:
      continue

    if index_by_class_name:
      key = camel_case.ToUnderscore(member.__name__)
    else:
      key = module.__name__.rsplit('.', 1)[-1]

    if directly_constructable and not IsDirectlyConstructable(member):
      continue

    if index_by_class_name and key in discovered:
      # The same key may only be seen again for the very same class object.
      assert discovered[key] is member, (
          'Duplicate key_name with different objs detected: '
          'key=%s, obj1=%s, obj2=%s' % (key, discovered[key], member))
    else:
      discovered[key] = member

  return discovered
165
166
def IsDirectlyConstructable(cls):
  """Returns True if an instance of |cls| can be constructed without arguments.

  Args:
    cls: The class to inspect.

  Returns:
    True if |cls| has no |__init__|, inherits |object.__init__|, or has an
    |__init__| whose only parameter without a default value is the first
    positional one (i.e. |self|). False otherwise.
  """
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass|.
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class (object):| with |__init__| other than |object.__init__|.
  # inspect.getargspec() was removed in Python 3.11 and, before that, raised
  # ValueError for signatures with annotations or keyword-only parameters, so
  # prefer getfullargspec() and only fall back where it does not exist.
  get_full_spec = getattr(inspect, 'getfullargspec', None)
  if get_full_spec is None:
    args, _, _, defaults = inspect.getargspec(cls.__init__)
    required_kwonly = []
  else:
    spec = get_full_spec(cls.__init__)
    args = spec.args
    defaults = spec.defaults
    # A keyword-only parameter without a default must be passed by the
    # caller, so it also prevents argument-less construction.
    kwonly_defaults = spec.kwonlydefaults or {}
    required_kwonly = [
        name for name in spec.kwonlyargs if name not in kwonly_defaults]
  if defaults is None:
    defaults = ()
  # Return true if |self| is only arg without a default.
  return len(args) == len(defaults) + 1 and not required_kwonly
182
183
# Single-element list holding the number of names handed out so far; a list
# is used so the function below can update the value in place.
_COUNTER = [0]


def _GetUniqueModuleName():
  """Returns "module_<N>", where N is one higher on every call."""
  next_id = _COUNTER[0] + 1
  _COUNTER[0] = next_id
  return "module_" + str(next_id)
190