181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
|
"""Module to implement logics used for initializing extensions.
|
||
|
|
||
|
The implementations here should be stateless.
|
||
|
They should not depend on external state.
|
||
|
Anything that depends on external state should happen in __init__.py
|
||
|
"""
|
||
|
import importlib
|
||
|
import logging
|
||
|
import os
|
||
|
import types
|
||
|
from pathlib import Path
|
||
|
|
||
|
import torch
|
||
|
from torchaudio._internal.module_utils import eval_env
|
||
|
|
||
|
_LG = logging.getLogger(__name__)
|
||
|
_LIB_DIR = Path(__file__).parent.parent / "lib"
|
||
|
|
||
|
|
||
|
def _get_lib_path(lib: str):
|
||
|
suffix = "pyd" if os.name == "nt" else "so"
|
||
|
path = _LIB_DIR / f"{lib}.{suffix}"
|
||
|
return path
|
||
|
|
||
|
|
||
|
def _load_lib(lib: str) -> bool:
|
||
|
"""Load extension module
|
||
|
|
||
|
Note:
|
||
|
In case `torchaudio` is deployed with `pex` format, the library file
|
||
|
is not in a standard location.
|
||
|
In this case, we expect that `libtorchaudio` is available somewhere
|
||
|
in the search path of dynamic loading mechanism, so that importing
|
||
|
`_torchaudio` will have library loader find and load `libtorchaudio`.
|
||
|
This is the reason why the function should not raising an error when the library
|
||
|
file is not found.
|
||
|
|
||
|
Returns:
|
||
|
bool:
|
||
|
True if the library file is found AND the library loaded without failure.
|
||
|
False if the library file is not found (like in the case where torchaudio
|
||
|
is deployed with pex format, thus the shared library file is
|
||
|
in a non-standard location.).
|
||
|
If the library file is found but there is an issue loading the library,
|
||
|
(such as missing dependency) then this function raises the exception as-is.
|
||
|
|
||
|
Raises:
|
||
|
Exception:
|
||
|
If the library file is found, but there is an issue loading the library file,
|
||
|
(when underlying `ctype.DLL` throws an exception), this function will pass
|
||
|
the exception as-is, instead of catching it and returning bool.
|
||
|
The expected case is `OSError` thrown by `ctype.DLL` when a dynamic dependency
|
||
|
is not found.
|
||
|
This behavior was chosen because the expected failure case is not recoverable.
|
||
|
If a dependency is missing, then users have to install it.
|
||
|
"""
|
||
|
path = _get_lib_path(lib)
|
||
|
if not path.exists():
|
||
|
return False
|
||
|
torch.ops.load_library(path)
|
||
|
return True
|
||
|
|
||
|
|
||
|
def _import_sox_ext():
|
||
|
if os.name == "nt":
|
||
|
raise RuntimeError("sox extension is not supported on Windows")
|
||
|
if not eval_env("TORCHAUDIO_USE_SOX", True):
|
||
|
raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)")
|
||
|
|
||
|
ext = "torchaudio.lib._torchaudio_sox"
|
||
|
|
||
|
if not importlib.util.find_spec(ext):
|
||
|
raise RuntimeError(
|
||
|
# fmt: off
|
||
|
"TorchAudio is not built with sox extension. "
|
||
|
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
|
||
|
# fmt: on
|
||
|
)
|
||
|
|
||
|
_load_lib("libtorchaudio_sox")
|
||
|
return importlib.import_module(ext)
|
||
|
|
||
|
|
||
|
def _init_sox():
|
||
|
ext = _import_sox_ext()
|
||
|
ext.set_verbosity(0)
|
||
|
|
||
|
import atexit
|
||
|
|
||
|
torch.ops.torchaudio_sox.initialize_sox_effects()
|
||
|
atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects)
|
||
|
|
||
|
# Bundle functions registered with TORCH_LIBRARY into extension
|
||
|
# so that they can also be accessed in the same (lazy) manner
|
||
|
# from the extension.
|
||
|
keys = [
|
||
|
"get_info",
|
||
|
"load_audio_file",
|
||
|
"save_audio_file",
|
||
|
"apply_effects_tensor",
|
||
|
"apply_effects_file",
|
||
|
]
|
||
|
for key in keys:
|
||
|
setattr(ext, key, getattr(torch.ops.torchaudio_sox, key))
|
||
|
|
||
|
return ext
|
||
|
|
||
|
|
||
|
class _LazyImporter(types.ModuleType):
|
||
|
"""Lazily import module/extension."""
|
||
|
|
||
|
def __init__(self, name, import_func):
|
||
|
super().__init__(name)
|
||
|
self.import_func = import_func
|
||
|
self.module = None
|
||
|
|
||
|
# Note:
|
||
|
# Python caches what was retrieved with `__getattr__`, so this method will not be
|
||
|
# called again for the same item.
|
||
|
def __getattr__(self, item):
|
||
|
self._import_once()
|
||
|
return getattr(self.module, item)
|
||
|
|
||
|
def __repr__(self):
|
||
|
if self.module is None:
|
||
|
return f"<module '{self.__module__}.{self.__class__.__name__}(\"{self.name}\")'>"
|
||
|
return repr(self.module)
|
||
|
|
||
|
def __dir__(self):
|
||
|
self._import_once()
|
||
|
return dir(self.module)
|
||
|
|
||
|
def _import_once(self):
|
||
|
if self.module is None:
|
||
|
self.module = self.import_func()
|
||
|
# Note:
|
||
|
# By attaching the module attributes to self,
|
||
|
# module attributes are directly accessible.
|
||
|
# This allows to avoid calling __getattr__ for every attribute access.
|
||
|
self.__dict__.update(self.module.__dict__)
|
||
|
|
||
|
def is_available(self):
|
||
|
try:
|
||
|
self._import_once()
|
||
|
except Exception:
|
||
|
return False
|
||
|
return True
|
||
|
|
||
|
|
||
|
def _init_dll_path():
|
||
|
# On Windows Python-3.8+ has `os.add_dll_directory` call,
|
||
|
# which is called to configure dll search path.
|
||
|
# To find cuda related dlls we need to make sure the
|
||
|
# conda environment/bin path is configured Please take a look:
|
||
|
# https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
|
||
|
# Please note: if some path can't be added using add_dll_directory we simply ignore this path
|
||
|
for path in os.environ.get("PATH", "").split(";"):
|
||
|
if os.path.exists(path):
|
||
|
try:
|
||
|
os.add_dll_directory(path)
|
||
|
except Exception:
|
||
|
pass
|
||
|
|
||
|
|
||
|
def _check_cuda_version():
|
||
|
import torchaudio.lib._torchaudio
|
||
|
|
||
|
version = torchaudio.lib._torchaudio.cuda_version()
|
||
|
if version is not None and torch.version.cuda is not None:
|
||
|
version_str = str(version)
|
||
|
ta_version = f"{version_str[:-3]}.{version_str[-2]}"
|
||
|
t_version = torch.version.cuda.split(".")
|
||
|
t_version = f"{t_version[0]}.{t_version[1]}"
|
||
|
if ta_version != t_version:
|
||
|
raise RuntimeError(
|
||
|
"Detected that PyTorch and TorchAudio were compiled with different CUDA versions. "
|
||
|
f"PyTorch has CUDA version {t_version} whereas TorchAudio has CUDA version {ta_version}. "
|
||
|
"Please install the TorchAudio version that matches your PyTorch version."
|
||
|
)
|
||
|
return version
|