102 lines
3.4 KiB
Python
102 lines
3.4 KiB
Python
# Module-level flag: starts out False and is set to True further down only
# when _register_extensions() completes without ImportError/OSError.
_HAS_OPS = False
|
|
|
|
|
|
def _has_ops():
    """Report whether the compiled custom ops are available.

    This is the fallback definition and always answers ``False``; the module
    rebinds the name to a version returning ``True`` once the extension has
    been loaded successfully.
    """
    return False
|
|
|
|
|
|
def _register_extensions():
    """Find the compiled ``_C`` extension next to this file and load it via torch.

    On Windows the directory containing this file is first registered as a
    DLL search directory so that the extension's dependencies can be found.

    Raises:
        ImportError: if no ``_C`` extension module exists beside this file
            (or if ``torch`` itself cannot be imported).
        OSError: if the directory cannot be added to the Windows DLL search
            path; presumably also when the library fails to load -- confirm
            against ``torch.ops.load_library``.
    """
    import os
    # ``import importlib`` alone does not guarantee that the ``machinery``
    # submodule is bound as an attribute, so import the submodule explicitly.
    import importlib.machinery
    import torch

    # load the custom_op_library and register the custom ops
    lib_dir = os.path.dirname(__file__)
    if os.name == 'nt':
        # Register the main torchvision library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # 0x0001 is SEM_FAILCRITICALERRORS: suppress the critical-error dialog
        # while the DLL directory is being registered.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                # Return the cookie as a pointer so that NULL maps to None.
                kernel32.AddDllDirectory.restype = ctypes.c_void_p

            if sys.version_info >= (3, 8):
                os.add_dll_directory(lib_dir)
            elif with_load_library_flags:
                res = kernel32.AddDllDirectory(lib_dir)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                    raise err
        finally:
            # Always restore the previous error mode, even when registering
            # the directory failed and an exception is propagating.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec("_C")
    if ext_specs is None:
        # Give the caller something actionable instead of an empty ImportError.
        raise ImportError(f'Could not find the "_C" extension module in "{lib_dir}".')
    torch.ops.load_library(ext_specs.origin)
|
|
|
|
|
|
# Best-effort load of the compiled ops. A missing or unloadable extension is
# tolerated here; _HAS_OPS then stays False and _has_ops() keeps returning False.
try:
    _register_extensions()
except (ImportError, OSError):
    # Deliberately ignored: callers find out through _assert_has_ops().
    pass
else:
    _HAS_OPS = True

    def _has_ops():  # noqa: F811
        return True
|
|
|
|
|
|
def _assert_has_ops():
    """Raise ``RuntimeError`` unless ``_has_ops()`` reports the custom ops as loaded.

    Returns ``None`` when the ops are available.
    """
    if _has_ops():
        return

    message = (
        "Couldn't load custom C++ ops. This can happen if your PyTorch and "
        "torchvision versions are incompatible, or if you had errors while compiling "
        "torchvision from source. For further information on the compatible versions, check "
        "https://github.com/pytorch/vision#installation for the compatibility matrix. "
        "Please check your PyTorch version with torch.__version__ and your torchvision "
        "version with torchvision.__version__ and verify if they are compatible, and if not "
        "please reinstall torchvision so that it matches your PyTorch install."
    )
    raise RuntimeError(message)
|
|
|
|
|
|
def _check_cuda_version():
    """
    Verify that torchvision and PyTorch were compiled against the same CUDA version.

    Returns ``-1`` when ``_HAS_OPS`` is false; otherwise returns the integer
    reported by ``torch.ops.torchvision._cuda_version()``. Raises
    ``RuntimeError`` when the major or minor CUDA versions differ.
    """
    if not _HAS_OPS:
        return -1

    import torch

    _version = torch.ops.torchvision._cuda_version()
    # Skip the comparison when the op returns -1 or PyTorch reports no CUDA version.
    if _version == -1 or torch.version.cuda is None:
        return _version

    # Values below 10000 have a one-character major part, larger values a
    # two-character one; the minor digit sits two characters after the major.
    digits = str(_version)
    major_width = 1 if int(digits) < 10000 else 2
    tv_major = int(digits[:major_width])
    tv_minor = int(digits[major_width + 1])

    torch_parts = torch.version.cuda.split('.')
    t_major = int(torch_parts[0])
    t_minor = int(torch_parts[1])

    if (t_major, t_minor) != (tv_major, tv_minor):
        raise RuntimeError(
            "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
            f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has CUDA Version={tv_major}.{tv_minor}. "
            "Please reinstall the torchvision that matches your PyTorch install."
        )

    return _version
|
|
|
|
|
|
# Run the CUDA compatibility check at import time; a version mismatch between
# PyTorch and torchvision raises RuntimeError here.
_check_cuda_version()
|