1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Live entity inspection utilities.
16
17This module contains whatever inspect doesn't offer out of the box.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import inspect
25import itertools
26import linecache
27import sys
28import threading
29import types
30
31import six
32
33from tensorflow.python.util import tf_inspect
34
# This lock seems to help avoid linecache concurrency errors. It is held by
# getimmediatesource for the whole sequence of refreshing the linecache record
# and reading the source lines of the object.
_linecache_lock = threading.Lock()
37
38
# These functions test negative for isinstance(*, types.BuiltinFunctionType)
# and inspect.isbuiltin, and are generally not visible in globals().
# The mapping is keyed by the builtin's name and holds the builtin object
# itself.
# TODO(mdan): Remove this.
SPECIAL_BUILTINS = {
    'dict': dict,
    'enumerate': enumerate,
    'float': float,
    'int': int,
    'len': len,
    'list': list,
    'print': print,
    'range': range,
    'tuple': tuple,
    'type': type,
    'zip': zip
}

# xrange is only referenced when six reports Python 2, so the name is never
# evaluated on other interpreter versions.
if six.PY2:
  SPECIAL_BUILTINS['xrange'] = xrange
58
59
def islambda(f):
  """Returns True if the argument is a function that is named like a lambda.

  Both the function's `__name__` and the name recorded on its code object are
  consulted; a match on either one is sufficient.

  Args:
    f: Any
  Returns:
    Bool
  """
  if not tf_inspect.isfunction(f):
    return False
  # TODO(mdan): Look into checking only the code object.
  if not hasattr(f, '__name__') or not hasattr(f, '__code__'):
    return False
  lambda_name = '<lambda>'
  if f.__name__ == lambda_name:
    return True
  # Some wrappers can rename the function, but changing the name of the
  # code object is harder.
  return f.__code__.co_name == lambda_name
70
71
def isnamedtuple(f):
  """Returns True if the argument is a namedtuple-like class.

  A class qualifies when it subclasses tuple and has a `_fields` attribute
  that is a tuple containing only str values.

  Args:
    f: Any
  Returns:
    Bool
  """
  is_tuple_class = tf_inspect.isclass(f) and issubclass(f, tuple)
  if not is_tuple_class:
    return False
  # A missing attribute reads as None, which fails the tuple check below just
  # like any other non-tuple value would.
  field_names = getattr(f, '_fields', None)
  if not isinstance(field_names, tuple):
    return False
  for field_name in field_names:
    if not isinstance(field_name, str):
      return False
  return True
84
85
def isbuiltin(f):
  """Returns True if the argument is a built-in function.

  The argument is considered built-in when it is identical to any value of
  the builtins module, is an instance of types.BuiltinFunctionType, is
  reported as built-in by inspect.isbuiltin, or is eval.

  Args:
    f: Any
  Returns:
    Bool
  """
  builtin_values = six.moves.builtins.__dict__.values()
  return (
      any(f is candidate for candidate in builtin_values)
      or isinstance(f, types.BuiltinFunctionType)
      or inspect.isbuiltin(f)
      or f is eval)
98
99
def isconstructor(cls):
  """Returns True if the argument is an object constructor.

  In general, any object of type class is a constructor, with the exception
  of classes created using a callable metaclass.
  See below for why a callable metaclass is not a trivial combination:
  https://docs.python.org/2.7/reference/datamodel.html#customizing-class-creation

  Args:
    cls: Any
  Returns:
    Bool
  """
  if not inspect.isclass(cls):
    return False
  metaclass = cls.__class__
  # A metaclass that replaces type.__call__ takes over what happens when the
  # class is called, so the class can no longer be assumed to construct
  # instances of itself.
  overrides_call = (
      issubclass(metaclass, type)
      and hasattr(metaclass, '__call__')
      and metaclass.__call__ is not type.__call__)
  return not overrides_call
118
119
def _fix_linecache_record(obj):
  """Fixes potential corruption of linecache in the presence of functools.wraps.

  functools.wraps modifies the target object's __module__ field, which seems
  to confuse linecache in special instances, for example when the source is
  loaded from a .par file (see https://google.github.io/subpar/subpar.html).

  This function triggers a call to linecache.updatecache when a mismatch is
  detected between the object's __module__ property and the object's source
  file: that is, when a loaded module owns the object's source file but is not
  the module that the object's __module__ resolves to in sys.modules.

  Args:
    obj: Any

  Raises:
    TypeError: propagated from inspect.getfile when the object has a
      __module__ attribute but no source file can be determined for it.
  """
  if not hasattr(obj, '__module__'):
    return

  obj_file = inspect.getfile(obj)
  # __module__ holds a module *name*; resolve it to the module object so that
  # it can be meaningfully compared against the entries of sys.modules.
  declared_module = sys.modules.get(obj.__module__)

  # A snapshot of the loaded modules helps avoid "dict changed size during
  # iteration" errors.
  loaded_modules = tuple(sys.modules.values())
  for m in loaded_modules:
    if not hasattr(m, '__file__') or m.__file__ != obj_file:
      continue
    if declared_module is not m:
      # The file belongs to m, but the object claims a different module:
      # refresh the cache entry using the globals of the module that actually
      # owns the file.
      linecache.updatecache(obj_file, m.__dict__)
145
146
def getimmediatesource(obj):
  """Returns the source of obj itself, without following __wrapped__.

  This is a variant of inspect.getsource: the source lines are located with
  inspect.findsource and trimmed to the object's block with inspect.getblock.
  The whole lookup runs under the module's linecache lock.

  Args:
    obj: Any object whose source inspect.findsource can locate.
  Returns:
    A string containing the source code of the object's block.
  """
  with _linecache_lock:
    _fix_linecache_record(obj)
    source_lines, first_line = inspect.findsource(obj)
    block_lines = inspect.getblock(source_lines[first_line:])
    return ''.join(block_lines)
153
154
def getnamespace(f):
  """Returns the complete namespace of a function.

  Namespace is defined here as the mapping of all non-local variables to values.
  This includes the globals and the closure variables. Note that this captures
  the entire globals collection of the function, and may contain extra symbols
  that it does not actually use. Closure variables take precedence over
  globals of the same name.

  Args:
    f: User defined function.
  Returns:
    A dict mapping symbol names to values.
  """
  namespace = dict(six.get_function_globals(f))
  closure = six.get_function_closure(f)
  freevars = six.get_function_code(f).co_freevars
  if not (freevars and closure):
    return namespace

  for var_name, cell in zip(freevars, closure):
    try:
      contents = cell.cell_contents
    except ValueError:
      # Cell contains undefined variable, omit it from the namespace.
      continue
    namespace[var_name] = contents
  return namespace
179
180
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  If the object defines a parent module, the function attempts to use it to
  locate the object. When the parent module is reachable from the namespace
  but does not expose the object under any name (for example a function
  defined inside another function), that route is abandoned and the search
  continues as if no parent module were available.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
        inside modules.
    visited: Optional[Set[int]], ID of modules to avoid visiting.
  Returns: Union[str, None], the fully-qualified name that resolves to the value
      object_, or None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  # Copy the dict to avoid "changed size error" during concurrent invocations.
  # TODO(mdan): This is on the hot path. Can we avoid the copy?
  namespace = dict(namespace)

  for name, value in namespace.items():
    # The value may be referenced by more than one symbol, in which case
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is value:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  if parent is not None and parent is not object_:
    # The recursive calls below use max_depth=0, so they never descend into
    # other modules.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      # An object is not guaranteed to be an attribute of the module named by
      # its __module__, so a miss is handled explicitly rather than asserted:
      # a failed assert would crash the lookup, and with assertions disabled
      # a name ending in ".None" would be returned.
      if name_in_parent is not None:
        return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # The namespace is a private copy, so modules loaded while the search is
    # running cannot change its size during iteration.
    for name, value in namespace.items():
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None
242
243
def _get_unbound_function(m):
  """Returns the function underlying a method, or the argument itself.

  The `__func__` attribute is preferred, then `im_func`; an object that has
  neither is returned unchanged.
  """
  # TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
  # The failure case is for tf.keras.Model.
  for attr_name in ('__func__', 'im_func'):
    if hasattr(m, attr_name):
      return getattr(m, attr_name)
  return m
252
253
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  The method resolution order of owner_class is walked from the most basic
  class towards owner_class, and the first class that qualifies is returned.

  Args:
    m: A function or method; bound methods are reduced to their function.
    owner_class: The class whose hierarchy is searched.
  Returns:
    The first class in the walk whose attribute of the same name is m, or
    owner_class if no class in the hierarchy matches.
  """
  # Normalize bound functions to their respective unbound versions.
  m = _get_unbound_function(m)
  hierarchy = inspect.getmro(owner_class)
  method_name = m.__name__
  for candidate in reversed(hierarchy):
    if not hasattr(candidate, method_name):
      continue
    candidate_m = getattr(candidate, method_name)
    if _get_unbound_function(candidate_m) is m:
      return candidate
    if hasattr(m, '__self__') and m.__self__ == owner_class:
      # Python 3 class methods only work this way it seems :S
      return candidate
  return owner_class
267
268
def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations that inspect.currentframe.

  Limitations. This function will only work correctly if the owned class is
  visible in the caller's global or local variables.

  When several visible values own the function, their classes are compared
  and the one that is a subclass of all the others is returned.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible to m.

  Raises:
    ValueError: if several owners are found and none of their classes is a
      subclass of all the others.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, six.class_types):
      return m.__class__

  # Instance and class: return the class of "self".
  m_self = getattr(m, '__self__', None)
  if m_self is not None:
    if inspect.isclass(m_self):
      return m_self
    return m_self.__class__

  # Py2 methods may be bound or unbound, extract im_func to get the
  # underlying function. This does not depend on the candidate being
  # examined, so it is done once before the search.
  if hasattr(m, 'im_func'):
    m = m.im_func

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but more robust a method.
  owners = []
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      if hasattr(v, m.__name__):
        candidate = getattr(v, m.__name__)
        if hasattr(candidate, 'im_func'):
          candidate = candidate.im_func
        if candidate is m:
          owners.append(v)
  finally:
    # Drop the frame reference explicitly so it does not outlive the search.
    del caller_frame

  if not owners:
    return None
  if len(owners) == 1:
    return owners[0]

  # If multiple owners are found, and are not subclasses, raise an error.
  owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
  for o in owner_types:
    # issubclass against the whole tuple would always succeed because o is
    # itself a member of it, so each owner type is checked individually.
    if tf_inspect.isclass(o) and all(
        issubclass(o, other) for other in owner_types):
      return o
  raise ValueError('Found too many owners of %s: %s' % (m, owners))
341
342
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  The globals of the entity are scanned for values whose `__module__` is
  `__future__`; the names they are bound to are reported in sorted order.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings; empty when entity is neither a function nor a
    method.
  """
  is_callable_entity = (
      tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity))
  if not is_callable_entity:
    return tuple()
  future_names = [
      name for name, value in entity.__globals__.items()
      if getattr(value, '__module__', None) == '__future__'
  ]
  future_names.sort()
  return tuple(future_names)
356