278 lines
9.8 KiB
Python
278 lines
9.8 KiB
Python
import fnmatch
|
|
import importlib
|
|
import inspect
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from functools import partial
|
|
from inspect import signature
|
|
from types import ModuleType
|
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
|
|
|
|
from torch import nn
|
|
|
|
from .._internally_replaced_utils import load_state_dict_from_url
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "WeightsEnum",
    "Weights",
    "get_model",
    "get_model_builder",
    "get_model_weights",
    "get_weight",
    "list_models",
]
|
|
|
|
|
|
@dataclass
class Weights:
    """
    Bundles the attributes that describe one set of pre-trained weights.

    Args:
        url (str): Location the weights are fetched from.
        transforms (Callable): A constructor for the preprocessing method (or validation preset transforms) needed
            to use the model. A constructor is stored instead of a ready-made object because that object might hold
            memory, so its initialization is deferred until it is actually needed.
        meta (Dict[str, Any]): Meta-data about the weights and the model configuration. Entries can be purely
            informative (for example the number of parameters/flops, recipe link/methods used in training etc),
            configuration parameters needed to construct the model (for example `num_classes`), or meta-data needed
            to use the model (for example the `classes` of a classification model).
    """

    url: str
    transforms: Callable
    meta: Dict[str, Any]

    def __eq__(self, other: Any) -> bool:
        # A custom equality is needed so enum members survive deep-copying and deserialization: creating a new
        # instance of an already-defined enum (e.g. via deepcopy or unpickling) involves an equality check against
        # the defined members. `transforms` is often a `functools.partial`, and with
        # `fn = partial(...)` we get `deepcopy(fn) != fn`. Comparing the partial's components instead of the
        # partial objects keeps such copies equal to the defined members.
        # See https://github.com/pytorch/vision/pull/7107 for details.
        if not isinstance(other, Weights):
            return NotImplemented

        # Cheap field checks first; url is compared before meta, exactly as before.
        if self.url != other.url or self.meta != other.meta:
            return False

        mine, theirs = self.transforms, other.transforms
        if not (isinstance(mine, partial) and isinstance(theirs, partial)):
            return mine == theirs

        # Two partials are considered equal when they wrap the same callable with the same bound arguments.
        return mine.func == theirs.func and mine.args == theirs.args and mine.keywords == theirs.keywords
|
|
|
|
|
|
class WeightsEnum(Enum):
    """
    This class is the parent class of all model weights. Each model building method receives an optional `weights`
    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
    `Weights`.

    Args:
        value (Weights): The data class entry with the weight information.
    """

    @classmethod
    def verify(cls, obj: Any) -> Any:
        """
        Normalize ``obj`` into a member of this enum.

        Args:
            obj: ``None``, a member of ``cls``, or a string naming a member either bare (``"IMAGENET1K_V1"``) or
                qualified with this enum's class name (``"<ClassName>.IMAGENET1K_V1"``).

        Returns:
            ``None`` if ``obj`` is ``None``, otherwise the corresponding member of ``cls``.

        Raises:
            KeyError: If a string does not name a member of ``cls``.
            TypeError: If ``obj`` is not ``None``, a string, or an instance of ``cls``.
        """
        if obj is None:
            return obj
        if isinstance(obj, str):
            # Strip only a *leading* "<ClassName>." qualifier. `str.replace` would also remove occurrences elsewhere
            # in the string and silently accept malformed names such as "<ClassName>.<ClassName>.MEMBER".
            prefix = cls.__name__ + "."
            member_name = obj[len(prefix) :] if obj.startswith(prefix) else obj
            return cls[member_name]
        if not isinstance(obj, cls):
            raise TypeError(
                f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
            )
        return obj

    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
        # Delegates to `load_state_dict_from_url` with this member's url; extra arguments are passed through as-is.
        return load_state_dict_from_url(self.url, *args, **kwargs)

    def __repr__(self) -> str:
        # Rendered as "<EnumClass>.<MEMBER>", the same qualified form `verify` accepts.
        return f"{self.__class__.__name__}.{self._name_}"

    # Convenience accessors forwarding to the fields of the underlying `Weights` value.
    @property
    def url(self) -> str:
        return self.value.url

    @property
    def transforms(self) -> Callable:
        return self.value.transforms

    @property
    def meta(self) -> Dict[str, Any]:
        return self.value.meta
|
|
|
|
|
|
def get_weight(name: str) -> WeightsEnum:
    """
    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

    The enum class is searched for in the parent package of this module and then in each of that package's
    sub-packages (members that are modules whose file is an ``__init__.py``).

    Args:
        name (str): The name of the weight enum entry, in the form ``"<EnumClass>.<MEMBER>"``.

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``name`` does not contain exactly one ``"."``, or no ``WeightsEnum`` subclass named
            ``<EnumClass>`` is found.
        KeyError: If the enum class is found but has no member named ``<MEMBER>``.
    """
    try:
        enum_name, value_name = name.split(".")
    except ValueError:
        raise ValueError(f"Invalid weight name provided: '{name}'.") from None

    # Parent package of this module (e.g. "a.b.c" -> "a.b").
    base_module_name = __name__.rpartition(".")[0]
    base_module = importlib.import_module(base_module_name)
    # Search the base package first, then its sub-packages. Modules whose `__file__` is missing or None are skipped
    # rather than raising AttributeError on `None.endswith`.
    model_modules = [base_module] + [
        module
        for _, module in inspect.getmembers(base_module, inspect.ismodule)
        if (getattr(module, "__file__", None) or "").endswith("__init__.py")
    ]

    for m in model_modules:
        candidate = m.__dict__.get(enum_name, None)
        # The attribute may be a non-class object (e.g. a model builder function sharing the name); `issubclass`
        # raises TypeError for those, so only classes are considered.
        if isinstance(candidate, type) and issubclass(candidate, WeightsEnum):
            return candidate[value_name]

    raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
|
|
|
|
|
|
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
    """
    Returns the weights enum class associated to the given model.

    Args:
        name (callable or str): Either the model builder function itself, or the name under which it is
            registered (resolved through ``get_model_builder``).

    Returns:
        weights_enum (WeightsEnum): The weights enum class associated with the model.
    """
    if isinstance(name, str):
        builder = get_model_builder(name)
    else:
        builder = name
    return _get_enum_from_fn(builder)
|
|
|
|
|
|
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
    """
    Internal method that gets the weight enum of a specific model builder method.

    The enum is read from the type annotation of the builder's ``weights`` parameter. The annotation may be the enum
    class itself or a typing construct such as ``Optional[SomeWeights]`` whose arguments include it.

    Args:
        fn (Callable): The builder method used to create the model.

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``fn`` has no ``weights`` parameter, or its annotation does not reference a ``WeightsEnum``
            subclass.
    """
    params = signature(fn).parameters
    if "weights" not in params:
        raise ValueError("The method is missing the 'weights' argument.")

    ann = params["weights"].annotation
    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
        return ann

    # Handle cases like Union[Optional, T]. Annotations without `__args__` (a missing annotation, a plain non-enum
    # class, or a string annotation) reach the ValueError below instead of raising AttributeError.
    # TODO: Replace with typing.get_args(ann) after python >= 3.8
    for t in getattr(ann, "__args__", ()):
        if isinstance(t, type) and issubclass(t, WeightsEnum):
            return t

    raise ValueError(
        "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
    )
|
|
|
|
|
|
# Type variable for the nn.Module (sub)type returned by a registered model builder.
M = TypeVar("M", bound=nn.Module)

# Registry mapping a model name to its builder callable; entries are added by `register_model`.
BUILTIN_MODELS: Dict[str, Callable[..., nn.Module]] = {}
|
|
|
|
|
|
def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    """
    Decorator factory that records a model builder in ``BUILTIN_MODELS``.

    Args:
        name (str, optional): Key to register the builder under. When ``None``, the builder's ``__name__`` is used.

    Returns:
        A decorator that registers the builder and returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the key is already registered.
    """

    def wrapper(builder: Callable[..., M]) -> Callable[..., M]:
        registry_key = builder.__name__ if name is None else name
        if registry_key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{registry_key}'.")
        BUILTIN_MODELS[registry_key] = builder
        return builder

    return wrapper
|
|
|
|
|
|
def list_models(
    module: Optional[ModuleType] = None,
    include: Union[Iterable[str], str, None] = None,
    exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
    """
    Returns a list with the names of registered models.

    Args:
        module (ModuleType, optional): The module from which we want to extract the available models. A model is
            included when the parent of its builder's ``__module__`` equals ``module.__name__``.
        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
            wildcards. With several filters, the result is the union of the individual filters' matches.
        exclude (str or Iterable[str], optional): Filter(s) applied after the include filters to remove models.
            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
            wildcards. With several filters, every model matching any individual filter is removed.

    Returns:
        models (list): A sorted list with the names of available models.
    """
    if module is None:
        candidates = set(BUILTIN_MODELS)
    else:
        candidates = {
            model_name
            for model_name, builder in BUILTIN_MODELS.items()
            if builder.__module__.rsplit(".", 1)[0] == module.__name__
        }

    # An empty/None include keeps every candidate.
    if include:
        include_patterns = [include] if isinstance(include, str) else include
        selected: Set[str] = set()
        for pattern in include_patterns:
            selected.update(fnmatch.filter(candidates, pattern))
    else:
        selected = candidates

    if exclude:
        exclude_patterns = [exclude] if isinstance(exclude, str) else exclude
        excluded: Set[str] = set()
        for pattern in exclude_patterns:
            excluded.update(fnmatch.filter(candidates, pattern))
        selected = selected - excluded

    return sorted(selected)
|
|
|
|
|
|
def get_model_builder(name: str) -> Callable[..., nn.Module]:
    """
    Gets the model name and returns the model builder method.

    ``name`` is lower-cased before the registry lookup, so only lower-case registry keys can be matched.

    Args:
        name (str): The name under which the model is registered.

    Returns:
        fn (Callable): The model builder method.

    Raises:
        ValueError: If no builder is registered under the lower-cased name.
    """
    normalized = name.lower()
    try:
        return BUILTIN_MODELS[normalized]
    except KeyError:
        raise ValueError(f"Unknown model {normalized}")
|
|
|
|
|
|
def get_model(name: str, **config: Any) -> nn.Module:
    """
    Gets the model name and configuration and returns an instantiated model.

    Args:
        name (str): The name under which the model is registered (resolved through ``get_model_builder``).
        **config (Any): parameters passed to the model builder method.

    Returns:
        model (nn.Module): The initialized model.
    """
    return get_model_builder(name)(**config)
|