# Traktor/myenv/Lib/site-packages/torchvision/datasets/places365.py
#
# 172 lines
# 7.1 KiB
# Python
# Raw Permalink Normal View History
#
# 2024-05-26 05:12:46 +02:00
import os
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin
from .folder import default_loader
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class Places365(VisionDataset):
    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
            ``val``.
        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
            high resolution ones.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
        RuntimeError: If ``download is True`` and the image archive is already extracted.
    """

    _SPLITS = ("train-standard", "train-challenge", "val")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        # split and small must be set before the file list is read: images_dir depends on both.
        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        # Category names and the per-split file list come from the devkit, which is
        # fetched on demand (if download=True) by the two loaders below.
        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``index``.

        The image is read with ``self.loader``; when ``self.transforms`` is set, the pair is passed through it.
        """
        file, target = self.imgs[index]
        image = self.loader(file)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.imgs)

    @property
    def variant(self) -> str:
        """``"challenge"`` for the challenge training split, ``"standard"`` for every other split."""
        return "challenge" if "challenge" in self.split else "standard"

    @property
    def images_dir(self) -> str:
        """Directory under ``root`` that holds the extracted images for the current split and resolution."""
        size = "256" if self.small else "large"
        if self.split.startswith("train"):
            # download_images() renames the extracted ``data_{size}`` directory to this
            # variant-specific name (presumably so both training variants can share one root).
            folder = f"data_{size}_{self.variant}"
        else:
            folder = f"{self.split}_{size}"
        return path.join(self.root, folder)

    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        """Read the category file, fetching the devkit first if it is missing or corrupted.

        Args:
            download (bool): If ``True``, download the devkit when the category file fails the
                integrity check; otherwise raise.

        Returns:
            A pair ``(classes, class_to_idx)``: the class names sorted alphabetically, and the
            name-to-index mapping read from the file.

        Raises:
            RuntimeError: If the file fails the integrity check and ``download`` is ``False``.
        """

        def process(line: str) -> Tuple[str, int]:
            # Each line holds a class name and its integer index separated by whitespace.
            name, idx = line.split()
            return name, int(idx)

        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file, encoding="utf-8") as fh:
            class_to_idx = dict(process(line) for line in fh if line.strip())

        return sorted(class_to_idx), class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        """Read the image list for the current split, fetching the devkit first if needed.

        Args:
            download (bool): If ``True``, download the devkit when the file list fails the
                integrity check; otherwise raise.

        Returns:
            A pair ``(images, targets)``:
                - ``images`` is a list of ``(image path, class_index)`` tuples, with paths placed under
                  :attr:`images_dir`.
                - ``targets`` holds the class indices in the same order as ``images``.

        Raises:
            RuntimeError: If the file fails the integrity check and ``download`` is ``False``.
        """
        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        # images_dir is a property; evaluate it once here instead of once per line of the list.
        images_dir = self.images_dir
        sep = "/"

        def process(line: str) -> Tuple[str, int]:
            image, idx = line.split()
            # Entries use "/" separators and may start with one; convert them to a native path.
            return path.join(images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)

        with open(file, encoding="utf-8") as fh:
            images = [process(line) for line in fh if line.strip()]

        # Unlike ``_, targets = zip(*images)``, this does not fail on an empty list.
        targets = [target for _, target in images]
        return images, targets

    def download_devkit(self) -> None:
        """Download and extract the devkit archive for the current :attr:`variant` into ``root``."""
        file, md5 = self._DEVKIT_META[self.variant]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        """Download and extract the image archive for the current split and resolution.

        Raises:
            RuntimeError: If :attr:`images_dir` already exists.
        """
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
                f"delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

        if self.split.startswith("train"):
            # The code expects training archives to extract to ``data_{size}``; rename that directory
            # to the variant-specific images_dir (strip the trailing "_{variant}" to get the source name).
            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)

    def extra_repr(self) -> str:
        """Extra lines shown in the dataset's ``repr``: the split and the image resolution flag."""
        return f"Split: {self.split}\nSmall: {self.small}"

    def _verify_split(self, split: str) -> str:
        """Validate ``split`` against :attr:`_SPLITS` and return it."""
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        """Return whether ``file`` exists and matches ``md5``.

        Raises:
            RuntimeError: If the check fails and ``download`` is ``False``, because nothing could repair it.
        """
        integrity = check_integrity(file, md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )
        return integrity