112 lines
4.1 KiB
Python
112 lines
4.1 KiB
Python
import os
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
|
import torch.utils.data as data
|
|
|
|
from ..utils import _log_api_usage_once
|
|
|
|
|
|
class VisionDataset(data.Dataset):
    """
    Base class for making datasets which are compatible with torchvision.

    Subclasses must override the ``__getitem__`` and ``__len__`` methods.

    Args:
        root (str or pathlib.Path, optional): Root directory of dataset. Only used for ``__repr__``.
            A ``str`` root has a leading ``~`` expanded via ``os.path.expanduser``.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If ``transforms`` is passed together with ``transform`` and/or
            ``target_transform``.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

    # Number of spaces used to indent every body line of ``__repr__``.
    _repr_indent = 4

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        _log_api_usage_once(self)
        # Only string roots are user-expanded; Path objects and None are stored as given.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # When separate transforms are given, wrap them so ``self.transforms`` is a
        # single callable taking ``(input, target)``.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        # getattr: subclasses that skip ``VisionDataset.__init__`` may not have set these attributes.
        root = getattr(self, "root", None)
        if root is not None:
            body.append(f"Root location: {root}")
        body += self.extra_repr().splitlines()
        transforms = getattr(self, "transforms", None)
        if transforms is not None:
            # Split multi-line reprs so that every line, not just the first, gets indented.
            body += repr(transforms).splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Return ``repr(transform)`` as lines, the first prefixed by ``head`` and the rest aligned under it."""
        # ``or [""]`` keeps an empty repr from making ``lines[0]`` raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Return extra text for ``__repr__``; each of its lines is indented into the body. Override in subclasses."""
        return ""
|
|
|
|
|
|
class StandardTransform:
    """Combine an optional input transform and an optional target transform into one callable.

    Args:
        transform (callable, optional): Applied to the input; skipped when ``None``.
        target_transform (callable, optional): Applied to the target; skipped when ``None``.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` with each element passed through its transform when one is set."""
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Return ``repr(transform)`` as lines, the first prefixed by ``head`` and the rest aligned under it."""
        # ``or [""]`` keeps an empty repr from making ``lines[0]`` raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")

        return "\n".join(body)
|