1# Copyright 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import fnmatch
6import inspect
7import os
8import re
9import sys
10
11from telemetry import decorators
12from telemetry.internal.util import camel_case
13from telemetry.internal.util import classes as classes_module
14
15
@decorators.Cache
def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Discover all modules in |start_dir| which match |pattern|.

  Walks |start_dir| recursively and imports every '.py' file whose name
  matches |pattern| and does not start with '.' or '_' (so '__init__.py' and
  private modules are skipped). Each file is imported under the dotted name
  derived from its path relative to |top_level_dir|.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing. It is placed
        at the front of sys.path only for the duration of each import.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of modules, in os.walk order.

  Raises:
    Whatever importing a discovered module raises; sys.path is restored
    before the exception propagates.
  """
  # Local import keeps this block self-contained; import_module is the
  # supported API for importing a module by its dotted name.
  import importlib

  # start_dir and top_level_dir must be consistent with each other.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  modules = []
  for dir_path, _, filenames in os.walk(start_dir):
    for filename in filenames:
      # Filter out hidden/private files and non-Python files.
      if filename.startswith(('.', '_')):
        continue
      if os.path.splitext(filename)[1] != '.py':
        continue
      if not fnmatch.fnmatch(filename, pattern):
        continue

      # Turn the path relative to top_level_dir into a dotted module name.
      module_rel_path = os.path.relpath(
          os.path.join(dir_path, filename), top_level_dir)
      module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])

      # Make sure that top_level_dir is the first path in sys.path in case
      # there are naming conflicts in module parts. Snapshot first so the
      # finally clause always has a saved state to restore.
      original_sys_path = sys.path[:]
      sys.path.insert(0, top_level_dir)
      try:
        # import_module returns the leaf module itself. The previous
        # __import__(name, fromlist=[True]) relied on a non-string fromlist
        # entry, which Python 3 rejects with TypeError whenever the name
        # resolves to a package (e.g. a 'foo/' package shadowing 'foo.py').
        modules.append(importlib.import_module(module_name))
      finally:
        # Restore in place so the sys.path list object itself is left clean
        # for any code holding a reference to it.
        sys.path[:] = original_sys_path
  return modules
59
60
# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
@decorators.Cache
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  Imports every module returned by DiscoverModules(start_dir, top_level_dir,
  pattern) and merges the per-module results of DiscoverClassesInModule. That
  function skips |base_class| itself, classes whose names start with '_', and
  classes that are only imported into a module rather than defined there.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When two
    modules produce the same key, the module discovered later wins.
  """
  classes = {}
  for module in DiscoverModules(start_dir, top_level_dir, pattern):
    # Merge in place. The old dict(a.items() + b.items()) copied the whole
    # accumulator for every module (quadratic) and raises TypeError on
    # Python 3, where items() views cannot be concatenated. update() keeps
    # the same precedence (later keys overwrite earlier ones). Only our own
    # fresh dict is mutated, never the per-module result, which is produced
    # by a decorators.Cache-wrapped function and may be shared.
    classes.update(DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable))
  return classes
94
95
@decorators.Cache
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Discover all classes in |module| which subclass |base_class|.

  A class is kept only if it is a strict subclass of |base_class| (the base
  class itself is skipped), its name does not start with '_', and it is
  defined in |module| rather than imported into it.

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
        If False, every match shares the module's last dotted component as
        its key, so only the last match in inspect.getmembers() order
        (sorted by attribute name) survives.
    directly_constructable: If True, keep only classes for which
        classes_module.IsDirectlyConstructable() returns True.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}
  """
  found = {}
  for _, candidate in inspect.getmembers(module):
    # Checks are evaluated left to right and short-circuit, so e.g.
    # issubclass() is only reached for actual classes.
    wanted = (inspect.isclass(candidate) and
              issubclass(candidate, base_class) and
              candidate is not base_class and
              not candidate.__name__.startswith('_') and
              # Drop classes merely imported into |module| from elsewhere.
              candidate.__module__ == module.__name__)
    if not wanted:
      continue

    if index_by_class_name:
      key = camel_case.ToUnderscore(candidate.__name__)
    else:
      key = module.__name__.rpartition('.')[2]

    if (directly_constructable and
        not classes_module.IsDirectlyConstructable(candidate)):
      continue
    found[key] = candidate

  return found
142
143
# Single-element list so _GetUniqueModuleName can bump the count without a
# `global` statement.
_counter = [0]


def _GetUniqueModuleName():
  """Returns a fresh module name of the form 'module_<N>'.

  N starts at 1 and grows by one on every call, tracked in _counter.
  """
  next_index = _counter[0] + 1
  _counter[0] = next_index
  return "module_" + str(next_index)
150