238 lines
8.7 KiB
Python
238 lines
8.7 KiB
Python
|
import importlib
|
||
|
from abc import ABC, abstractmethod
|
||
|
from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined]
|
||
|
_getattribute,
|
||
|
_Pickler,
|
||
|
whichmodule as _pickle_whichmodule,
|
||
|
)
|
||
|
from types import ModuleType
|
||
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
||
|
from ._mangling import demangle, get_mangle_prefix, is_mangled
|
||
|
|
||
|
# Names exported by `from ... import *`. Note that `sys_importer` and
# `_SysImporter` are not listed here.
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
|
||
|
|
||
|
|
||
|
class ObjNotFoundError(Exception):
    """Signals that an importer could not locate an object by searching for its name."""
|
||
|
|
||
|
|
||
|
class ObjMismatchError(Exception):
    """Signals that an importer resolved the user-provided name to a *different* object."""
|
||
|
|
||
|
|
||
|
class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    ``__module__`` and importing the result using ``__import__`` or
    ``importlib.import_module``.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will either succeed or
    raise an ObjNotFoundError/ObjMismatchError::

        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)  # obj_name may be a dotted __qualname__
        assert obj is obj2
    """

    # Module name -> module, searched by the default `whichmodule` for objects
    # that carry no `__module__`. It is only annotated here, never assigned;
    # presumably concrete subclasses populate it (TODO confirm per subclass).
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """

    @staticmethod
    def _getattr_dotted(root: Any, name: str) -> Tuple[Any, Any]:
        """Resolve the dotted attribute path `name` starting from `root`.

        Reproduces the historical behaviour of the private
        ``pickle._getattribute`` helper. That helper is deliberately not used
        because it is private and its signature/return shape is not stable
        across CPython releases.

        Returns:
            ``(attr, parent)``: the resolved attribute and the object it was
            read from.

        Raises:
            AttributeError: if any path component is missing, or is the
                ``<locals>`` marker of a function-local ``__qualname__``
                (such objects cannot be reached by attribute lookup).
        """
        current = root
        parent = root
        for part in name.split("."):
            if part == "<locals>":
                raise AttributeError(
                    f"Can't get local attribute {name!r} on {current!r}"
                )
            parent = current
            try:
                current = getattr(current, part)
            except AttributeError:
                raise AttributeError(
                    f"Can't get attribute {name!r} on {current!r}"
                ) from None
        return current, parent

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.

        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            `attr_name` may be dotted (e.g. a nested class's __qualname__).
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
            AttributeError: `obj` has neither a string-returning __reduce__,
                nor __qualname__, nor __name__.
        """
        if name is None and obj and _Pickler.dispatch.get(type(obj)) is None:
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort: if __reduce__ fails we fall back to
                    # __qualname__ / __name__ below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object
        try:
            module = self.import_module(module_name)
            obj2, _ = self._getattr_dotted(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        # A different object lives under the same name. Build a message that
        # tells the user which environment each of the two objects came from.
        def get_obj_info(obj):
            assert name is not None
            module_name = self.whichmodule(obj, name)
            is_mangled_ = is_mangled(module_name)
            location = (
                get_mangle_prefix(module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Adapted from pickle.py's whichmodule, but searching `self.modules`
        instead of sys.modules. Returns ``obj.__module__`` when it is set;
        otherwise the first entry of `self.modules` (skipping ``__main__``,
        ``__mp_main__`` and ``None`` entries) in which the dotted `name`
        resolves to `obj` itself; otherwise ``"__main__"``.
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a list copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if (
                module_name == "__main__"
                or module_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if self._getattr_dotted(module, name)[0] is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"
|
||
|
|
||
|
|
||
|
class _SysImporter(Importer):
    """Importer that follows the interpreter's ordinary import behaviour.

    Modules are obtained with ``importlib.import_module``, and the module an
    object belongs to is found with pickle's own ``whichmodule``.
    """

    def import_module(self, module_name: str):
        # Defer entirely to the standard import machinery.
        imported = importlib.import_module(module_name)
        return imported

    def whichmodule(self, obj: Any, name: str) -> str:
        # Unlike Importer.whichmodule, pickle's version is not restricted to
        # `self.modules`.
        owner = _pickle_whichmodule(obj, name)
        return owner
|
||
|
|
||
|
|
||
|
# Shared instance of the default importer, which resolves modules via
# importlib.import_module and objects via pickle's whichmodule.
sys_importer = _SysImporter()
|
||
|
|
||
|
|
||
|
class OrderedImporter(Importer):
    """Compound importer that consults a sequence of importers, one after another.

    Whichever importer is first to produce a usable result is the one used.
    """

    def __init__(self, *args):
        self._importers: List[Importer] = [*args]

    def _is_torchpackage_dummy(self, module):
        """Return True iff `module` is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This breaks re-packaging an object after a
        real dependency on `a` has been added, because OrderedImporter would
        resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        # Only torch.package *packages* (modules with __path__) can be dummies.
        packaged_package = getattr(module, "__torch_package__", False) and hasattr(
            module, "__path__"
        )
        if not packaged_package:
            return False
        # A dummy has no backing source file: __file__ is absent or None.
        return getattr(module, "__file__", None) is None

    def import_module(self, module_name: str) -> ModuleType:
        most_recent_miss = None
        for candidate in self._importers:
            # Validated lazily: only importers actually reached are checked.
            if not isinstance(candidate, Importer):
                raise TypeError(
                    f"{candidate} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                found = candidate.import_module(module_name)
                usable = not self._is_torchpackage_dummy(found)
            except ModuleNotFoundError as err:
                most_recent_miss = err
                continue
            if usable:
                return found
            # Dummy package: keep looking in later importers.

        if most_recent_miss is None:
            raise ModuleNotFoundError(module_name)
        raise most_recent_miss

    def whichmodule(self, obj: Any, name: str) -> str:
        # Lazily ask each importer in order; the first answer other than
        # "__main__" wins, and "__main__" is the fallback.
        answers = (candidate.whichmodule(obj, name) for candidate in self._importers)
        return next((answer for answer in answers if answer != "__main__"), "__main__")
|