1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5import os
6import platform
7
_is_windows = platform.system() == "Windows"
# Anchor on an absolute path so the search path below does not depend on the
# current working directory should __file__ ever be relative.
_this_directory = os.path.dirname(os.path.abspath(__file__))

# The standard LLVM build/install tree for Windows is laid out as:
#   bin/
#     MLIRPublicAPI.dll
#   python/
#     _mlir.*.pyd (dll extension)
#     mlir/
#       _dlloader.py (this file)
# First check the python/ directory level for DLLs co-located with the pyd
# file, and then fall back to searching the bin/ directory.
# Entries are normalized so ".." segments are collapsed; Windows resolves ".."
# lexically, so this does not change which files are found, but it keeps the
# paths reported in error messages readable.
# TODO: This should be configurable at some point.
_dll_search_path = [
  os.path.normpath(os.path.join(_this_directory, "..")),
  os.path.normpath(os.path.join(_this_directory, "..", "..", "bin")),
]

# Stash loaded DLL handles so references to them are retained for the life of
# the process.
_loaded_dlls = []
28
def preload_dependency(public_name):
  """Preloads a dylib by its soname or DLL name.

  On Windows, this loads ``<public_name>.dll`` from the module's DLL search
  path. Doing this before loading a library that depends on it populates the
  process with the preloaded copy, so that the dependent library resolves to
  this preloaded version.

  On every other platform this is currently a no-op. On macOS, resolution is
  completely path based, so preloading has no effect. On Linux, as long as
  RPATHs are set up properly, resolution is also path based. Preloading could
  act as an escape hatch for relocatable distributions there, but it is not
  implemented.

  Args:
    public_name: Library base name without extension (e.g. "MLIRPublicAPI");
      on Windows ".dll" is appended to form the file name.

  Raises:
    RuntimeError: on Windows, if the DLL is not found in any search directory.
  """
  if _is_windows:
    _preload_dependency_windows(public_name)
42
43
def _preload_dependency_windows(public_name):
  """Loads ``<public_name>.dll`` from the first search directory containing it.

  Directories in ``_dll_search_path`` are checked in order. The first one that
  contains a regular file named ``<public_name>.dll`` wins. The resulting
  ``ctypes.CDLL`` handle is appended to ``_loaded_dlls`` so a reference to it
  is retained.

  Args:
    public_name: DLL base name without the ".dll" extension.

  Raises:
    RuntimeError: if no directory in the search path contains the DLL.
    OSError: propagated from ``ctypes.CDLL`` if the file cannot be loaded.
  """
  dll_basename = public_name + ".dll"
  for search_dir in _dll_search_path:
    candidate_path = os.path.join(search_dir, dll_basename)
    # isfile (not exists): a directory that happens to share the DLL's name
    # must not shadow a real DLL later in the search path.
    if os.path.isfile(candidate_path):
      found_path = candidate_path
      break
  else:
    raise RuntimeError(
      f"Unable to find dependency DLL {dll_basename} in search "
      f"path {_dll_search_path}")

  # Imported lazily: only needed on the Windows preload path.
  import ctypes
  _loaded_dlls.append(ctypes.CDLL(found_path))
60