import contextlib
import errno
import hashlib
import json
import os
import re
import shutil
import sys
import tempfile
import torch
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Any
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, Request
from urllib.parse import urlparse  # noqa: F401
from torch.serialization import MAP_LOCATION


class _Faketqdm:  # type: ignore[no-redef]

    def __init__(self, total=None, disable=False,
                 unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs unless you want to reinvent tqdm

    def update(self, n):
        if self.disable:
            return

        self.n += n
        if self.total is None:
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want
    def set_description(self, *args, **kwargs):
        pass

    def write(self, s):
        sys.stderr.write(f"{s}\n")

    def close(self):
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return

        sys.stderr.write('\n')


try:
    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
except ImportError:
    tqdm = _Faketqdm


__all__ = [
    'download_url_to_file',
    'get_dir',
    'help',
    'list',
    'load',
    'load_state_dict_from_url',
    'set_dir',
]

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
VAR_DEPENDENCY = 'dependencies'
MODULE_HUBCONF = 'hubconf.py'
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None


@contextlib.contextmanager
def _add_to_sys_path(path):
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)


# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
    import importlib.util
    from importlib.abc import Loader
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    assert isinstance(spec.loader, Loader)
    spec.loader.exec_module(module)
    return module


def _remove_if_exists(path):
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)


def _git_archive_link(repo_owner, repo_name, ref):
    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip
    return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}"


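# For illustration (read straight off the f-string above), a call like
# _git_archive_link("pytorch", "vision", "main") yields
# "https://github.com/pytorch/vision/zipball/main".

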
def _load_attr_from_module(module, func_name):
    # Check if callable is defined in the module
    if func_name not in dir(module):
        return None
    return getattr(module, func_name)


def _get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(ENV_TORCH_HOME,
                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')))
    return torch_home


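# Resolution order, read directly from the lookup above: $TORCH_HOME if set,
# else $XDG_CACHE_HOME/torch if $XDG_CACHE_HOME is set, else ~/.cache/torch.

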
def _parse_repo_info(github):
    if ':' in github:
        repo_info, ref = github.split(':')
    else:
        repo_info, ref = github, None
    repo_owner, repo_name = repo_info.split('/')

    if ref is None:
        # The ref wasn't specified by the user, so we need to figure out the
        # default branch: main or master. Our assumption is that if main exists
        # then it's the default branch, otherwise it's master.
        try:
            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
                ref = 'main'
        except HTTPError as e:
            if e.code == 404:
                ref = 'master'
            else:
                raise
        except URLError as e:
            # No internet connection, need to check for cache as last resort
            for possible_ref in ("main", "master"):
                if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
                    ref = possible_ref
                    break
            if ref is None:
                raise RuntimeError(
                    "It looks like there is no internet connection and the "
                    f"repo could not be found in the cache ({get_dir()})"
                ) from e
    return repo_owner, repo_name, ref


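# For illustration (the ref below is a placeholder, not a pinned version):
# _parse_repo_info('pytorch/vision:v0.10.0') returns ('pytorch', 'vision', 'v0.10.0'),
# while _parse_repo_info('pytorch/vision') probes github.com to pick 'main' or 'master'.

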
def _read_url(url):
    with urlopen(url) as r:
        return r.read().decode(r.headers.get_content_charset('utf-8'))


def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
    # Use urlopen to avoid depending on local git.
    headers = {'Accept': 'application/vnd.github.v3+json'}
    token = os.environ.get(ENV_GITHUB_TOKEN)
    if token is not None:
        headers['Authorization'] = f'token {token}'
    for url_prefix in (
            f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
            f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
        page = 0
        while True:
            page += 1
            url = f'{url_prefix}?per_page=100&page={page}'
            response = json.loads(_read_url(Request(url, headers=headers)))
            # Empty response means no more data to process
            if not response:
                break
            for br in response:
                if br['name'] == ref or br['commit']['sha'].startswith(ref):
                    return

    raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
                     'If it\'s a commit from a forked repo, please call hub.load() with the forked repo directly.')


def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    os.makedirs(hub_dir, exist_ok=True)
    # Parse github repo information
    repo_owner, repo_name, ref = _parse_repo_info(github)
    # GitHub allows branch names containing a slash '/',
    # which causes confusion with paths on both Linux and Windows.
    # Backslashes are not allowed in GitHub branch names, so there is
    # no need to worry about them.
    normalized_br = ref.replace('/', '_')
    # GitHub renames the folder repo-v1.x.x to repo-1.x.x.
    # We don't know the repo name until we download the zip file
    # and inspect the name from it.
    # To check whether a cached repo exists, we need to normalize folder names.
    owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
    repo_dir = os.path.join(hub_dir, owner_name_branch)
    # Check that the repo is in the trusted list
    _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        if verbose:
            sys.stderr.write(f'Using cache found in {repo_dir}\n')
    else:
        # Validate the tag/branch is from the original repo instead of a forked repo
        if not skip_validation:
            _validate_not_a_forked_repo(repo_owner, repo_name, ref)

        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
        _remove_if_exists(cached_file)

        try:
            url = _git_archive_link(repo_owner, repo_name, ref)
            sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
            download_url_to_file(url, cached_file, progress=False)
        except HTTPError as err:
            if err.code == 300:
                # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
                # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags/.
                # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
                # Here, we do the same as git: we throw a warning and assume the user wanted the branch.
                warnings.warn(
                    f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
                    "Torchhub will now assume that it's a branch. "
                    "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
                    "refs/tags/tag_name as the ref. That might require using skip_validation=True."
                )
                disambiguated_branch_ref = f"refs/heads/{ref}"
                url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
                download_url_to_file(url, cached_file, progress=False)
            else:
                raise

        with zipfile.ZipFile(cached_file) as cached_zipfile:
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            # Unzip the code and rename the base folder
            cached_zipfile.extractall(hub_dir)

        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir


def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath) as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = '_'.join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 2.0 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                "confirmation if the repo is not already trusted. This will eventually be the default behaviour")
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
        if response.lower() in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif response.lower() in ("n", "no", ""):
            raise Exception("Untrusted repository.")
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")


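# Summary of the branches above: trust_repo=None only warns (for now); True trusts and
# records the repo without prompting; False always prompts; "check" prompts only when the
# owner is not in the trusted_list file, the repo was never downloaded before, and the
# owner is not in _TRUSTED_REPO_OWNERS.

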
def _check_module_exists(name):
    import importlib.util
    return importlib.util.find_spec(name) is not None


def _check_dependencies(m):
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)

    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if len(missing_deps):
            raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}")


def _load_entry_from_hubconf(m, model):
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Note that if a missing dependency is imported at the top level of hubconf, it will
    # throw before this function. It's a chicken-and-egg situation: we have to load
    # hubconf to know what the dependencies are, but importing hubconf requires the
    # missing package. This is fine, Python will throw a proper error message for users.
    _check_dependencies(m)

    func = _load_attr_from_module(m, model)

    if func is None or not callable(func):
        raise RuntimeError(f'Cannot find callable {model} in hubconf')

    return func


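# For illustration only, a minimal hubconf.py that the machinery above can consume might
# look like the sketch below; the entrypoint name `tiny_linear` is a placeholder and not
# part of this module:
#
#     dependencies = ['torch']  # module names verified by _check_dependencies()
#
#     def tiny_linear(in_features=4, out_features=2):
#         """Docstring surfaced by torch.hub.help(); must return the constructed object."""
#         import torch
#         return torch.nn.Linear(in_features, out_features)

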
def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, the default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if the environment
    variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), 'hub')


def set_dir(d):
    r"""
    Optionally set the Torch Hub directory used to save downloaded models & weights.

    Args:
        d (str): path to a local folder to save downloaded models & weights.
    """
    global _hub_dir
    _hub_dir = os.path.expanduser(d)


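# Usage sketch (the path is illustrative): torch.hub.set_dir('~/my_hub_cache') makes
# get_dir() return the expanded path instead of $TORCH_HOME/hub for the rest of the
# process; downloaded repos and the trusted_list file then live under that directory.

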
def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
    r"""
    List all callable entrypoints available in the repo specified by ``github``.

    Args:
        github (str): a string with format "repo_owner/repo_name[:ref]" with an optional
            ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if
            it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensure that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Default is ``True``.

    Returns:
        list: The available callable entrypoints.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
                                    skip_validation=skip_validation)

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    # We treat functions that start with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints


def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Show the docstring of entrypoint ``model``.

    Args:
        github (str): a string with format <repo_owner/repo_name[:ref]> with an optional
            ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed
            to be ``main`` if it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        model (str): the name of an entrypoint defined in the repo's ``hubconf.py``
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the ref
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensure that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
                                    skip_validation=skip_validation)

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__


def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
         skip_validation=False,
         **kwargs):
    r"""
    Load a model from a github repo or a local directory.

    Note: Loading a model is the typical use case, but this can also be used
    to load other objects such as tokenizers, loss functions, etc.

    If ``source`` is 'github', ``repo_or_dir`` is expected to be
    of the form ``repo_owner/repo_name[:ref]`` with an optional
    ref (a tag or a branch).

    If ``source`` is 'local', ``repo_or_dir`` is expected to be a
    path to a local directory.

    Args:
        repo_or_dir (str): If ``source`` is 'github',
            this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with
            an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified,
            the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.
            If ``source`` is 'local' then it should be a path to a local directory.
        model (str): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): the corresponding args for callable ``model``.
        source (str, optional): 'github' or 'local'. Specifies how
            ``repo_or_dir`` is to be interpreted. Default is 'github'.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensure that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        force_reload (bool, optional): whether to force a fresh download of
            the github repo unconditionally. Does not have any effect if
            ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Does not have any effect if ``source = 'local'``.
            Default is ``True``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        **kwargs (optional): the corresponding kwargs for callable ``model``.

    Returns:
        The output of the ``model`` callable when called with the given
        ``*args`` and ``**kwargs``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # from a github repo
        >>> repo = 'pytorch/vision'
        >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
        >>> # from a local directory
        >>> path = '/some/local/path/pytorch/vision'
        >>> # xdoctest: +SKIP
        >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
    """
    source = source.lower()

    if source not in ('github', 'local'):
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
                                           verbose=verbose, skip_validation=skip_validation)

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model


def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Load a model from a local directory with a ``hubconf.py``.

    Args:
        hubconf_dir (str): path to a local directory that contains a
            ``hubconf.py``.
        model (str): name of an entrypoint defined in the directory's
            ``hubconf.py``.
        *args (optional): the corresponding args for callable ``model``.
        **kwargs (optional): the corresponding kwargs for callable ``model``.

    Returns:
        a single model with corresponding pretrained weights.

    Example:
        >>> # xdoctest: +SKIP("stub local path")
        >>> path = '/some/local/path/pytorch/vision'
        >>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
    """
    with _add_to_sys_path(hubconf_dir):
        hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

        entry = _load_entry_from_hubconf(hub_module, model)
        model = entry(*args, **kwargs)

    return model


def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
                         progress: bool = True) -> None:
    r"""Download object at the given URL to a local path.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (str, optional): If not None, the SHA256 hash of the downloaded file should
            start with ``hash_prefix``. Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after
    # download is complete. This prevents a local working checkpoint
    # being overwritten by a broken download.
    # We deliberately do not use NamedTemporaryFile to avoid restrictive
    # file permissions being applied to the downloaded file.
    dst = os.path.expanduser(dst)
    for seq in range(tempfile.TMP_MAX):
        tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
        try:
            f = open(tmp_dst, 'w+b')
        except FileExistsError:
            continue
        break
    else:
        raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with tqdm(total=file_size, disable=not progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(READ_DATA_CHUNK)
                if len(buffer) == 0:
                    break
                f.write(buffer)  # type: ignore[possibly-undefined]
                if hash_prefix is not None:
                    sha256.update(buffer)  # type: ignore[possibly-undefined]
                pbar.update(len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)


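# For illustration, following the naming convention matched by HASH_REGEX above: a file
# published as resnet18-5c106cde.pth can be fetched with
# download_url_to_file(url, dst, hash_prefix='5c106cde'); the call raises RuntimeError if
# the SHA256 of the downloaded bytes does not start with that prefix.

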
# Hub used to automatically extract from zipfiles manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since the zipfile format is now the default for torch.save().
def _is_legacy_zip_format(filename: str) -> bool:
    if zipfile.is_zipfile(filename):
        infolist = zipfile.ZipFile(filename).infolist()
        return len(infolist) == 1 and not infolist[0].is_dir()
    return False


def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
    warnings.warn('Falling back to the old format < 1.6. This support will be '
                  'deprecated in favor of the default zipfile format introduced in 1.6. '
                  'Please redo torch.save() to save it in the new zipfile format.')
    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    # E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if len(members) != 1:
            raise RuntimeError('Only one file (not a dir) is allowed in the zipfile')
        f.extractall(model_dir)
        extracted_name = members[0].filename
        extracted_file = os.path.join(model_dir, extracted_name)
    return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)


def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None,
    weights_only: bool = False,
) -> Dict[str, Any]:
    r"""Loads the Torch serialized object at the given URL.

    If the downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in ``model_dir``, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
        weights_only (bool, optional): If True, only weights will be loaded and no complex pickled objects.
            Recommended for untrusted sources. See :func:`~torch.load` for more details.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)
    return torch.load(cached_file, map_location=map_location, weights_only=weights_only)