Traktor/myenv/Lib/site-packages/torch/utils/data/dataset.py

489 lines
19 KiB
Python
Raw Normal View History

2024-05-23 01:57:24 +02:00
import bisect
import itertools
import math
import warnings
from typing import (
cast,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm
from ... import Generator, Tensor
# Public names exported by this module.
__all__ = [
"Dataset",
"IterableDataset",
"TensorDataset",
"StackDataset",
"ConcatDataset",
"ChainDataset",
"Subset",
"random_split",
]
# Covariant sample type used to parameterize the dataset classes below.
T_co = TypeVar("T_co", covariant=True)
# Plain (invariant) type variable; used in the signature of `random_split`.
T = TypeVar("T")
# Shape of a dict-style stacked sample: string keys mapped to samples.
T_dict = Dict[str, T_co]
# Shape of a tuple-style stacked sample: a variable-length tuple of samples.
T_tuple = Tuple[T_co, ...]
# Sample type of `StackDataset`, constrained to either the tuple or dict form.
T_stack = TypeVar("T_stack", T_tuple, T_dict)
class Dataset(Generic[T_co]):
    r"""Base class for map-style datasets.

    A map-style dataset maps keys to data samples. Every subclass should
    override :meth:`__getitem__` so that a sample can be fetched for a given
    key. Overriding :meth:`__len__` is optional, but many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader` expect it to return the size of
    the dataset. A subclass may additionally provide :meth:`__getitems__` to
    speed up batched loading; it receives the list of sample indices of a
    batch and returns the list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        """Fetch the sample stored under ``index``; subclasses must override this."""
        message = "Subclasses of Dataset should implement __getitem__."
        raise NotImplementedError(message)

    # ``def __getitems__(self, indices: List) -> List[T_co]`` is deliberately
    # left undefined here to prevent false-positives in the fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        """Return the concatenation of this dataset followed by ``other``."""
        members = [self, other]
        return ConcatDataset(members)

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
class IterableDataset(Dataset[T_co], Iterable[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[T_co]):
        """Return a :class:`ChainDataset` that yields from ``self`` and then ``other``."""
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset that wraps a group of tensors.

    A sample is the tuple obtained by indexing every wrapped tensor along its
    first dimension with the same index.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Compare every tensor's leading dimension against the first tensor's;
        # with no tensors at all the loop body never runs.
        for candidate in tensors:
            assert (
                tensors[0].size(0) == candidate.size(0)
            ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        """Return a tuple holding ``tensor[index]`` for each wrapped tensor."""
        picked = [source[index] for source in self.tensors]
        return tuple(picked)

    def __len__(self):
        """Return the size of the first tensor's leading dimension."""
        leading = self.tensors[0]
        return leading.size(0)
class StackDataset(Dataset[T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.

    Raises:
        ValueError: if both positional and keyword datasets are given, if no
            dataset is given, or if the datasets differ in length.
    """

    # Either the positional datasets (tuple) or the keyword datasets (dict).
    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
        if args and kwargs:
            raise ValueError(
                "Supported either ``tuple``- (via ``args``) or "
                "``dict``- (via ``kwargs``) like input/output, but both types are given."
            )
        if not args and not kwargs:
            raise ValueError("At least one dataset should be passed")
        # Validate lengths once, regardless of which input style was used.
        members = args if args else tuple(kwargs.values())
        self._length = len(members[0])  # type: ignore[arg-type]
        if any(self._length != len(dataset) for dataset in members):  # type: ignore[arg-type]
            raise ValueError("Size mismatch between datasets")
        self.datasets = args if args else kwargs

    @staticmethod
    def _fetch_batch(dataset, indices: list):
        """Return one sample per index, via ``dataset.__getitems__`` when it exists."""
        batched_fetch = getattr(dataset, "__getitems__", None)
        if not callable(batched_fetch):
            return [dataset[idx] for idx in indices]
        items = batched_fetch(indices)
        if len(items) != len(indices):
            raise ValueError(
                "Nested dataset's output size mismatch."
                f" Expected {len(indices)}, got {len(items)}"
            )
        return items

    def __getitem__(self, index):
        """Return the stacked sample at ``index`` as a dict or a tuple."""
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        """Return the stacked samples for ``indices`` as a list of dicts or tuples.

        Raises:
            ValueError: if a nested ``__getitems__`` returns a number of
                samples different from ``len(indices)``.
        """
        # add batched sampling support when parent datasets supports it.
        if isinstance(self.datasets, dict):
            dict_batch: List[T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                items = self._fetch_batch(dataset, indices)
                for data, d_sample in zip(items, dict_batch):
                    d_sample[k] = data
            return dict_batch

        # tuple data: fetch one column per dataset, then transpose into samples.
        columns = [self._fetch_batch(dataset, indices) for dataset in self.datasets]
        tuple_batch: List[T_tuple] = [tuple(sample) for sample in zip(*columns)]
        return tuple_batch

    def __len__(self):
        """Return the common length of the stacked datasets."""
        return self._length
class ConcatDataset(Dataset[T_co]):
    r"""Dataset formed by concatenating several datasets end to end.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset[T_co]]
    # Running totals of the member lengths; entry ``i`` is the index one past
    # the last sample of ``datasets[i]``.
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(item)`` over ``sequence``."""
        return list(itertools.accumulate(len(item) for item in sequence))

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert self.datasets, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        assert all(
            not isinstance(member, IterableDataset) for member in self.datasets
        ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all members."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global position ``idx`` (negative values count from the end)."""
        if idx < 0:
            size = len(self)
            if -idx > size:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            # Rebind rather than add in place so a caller's index object is untouched.
            idx = size + idx
        # The first running total strictly greater than idx marks the owning member.
        member_pos = bisect.bisect_right(self.cumulative_sizes, idx)
        local_idx = (
            idx if member_pos == 0 else idx - self.cumulative_sizes[member_pos - 1]
        )
        return self.datasets[member_pos][local_idx]

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``; emits a ``DeprecationWarning``."""
        notice = "cummulative_sizes attribute is renamed to cumulative_sizes"
        warnings.warn(notice, DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
class ChainDataset(IterableDataset):
    r"""Dataset that chains several :class:`IterableDataset` s one after another.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Stored as given; members are only type-checked when iterated or measured.
        self.datasets = datasets

    def __iter__(self):
        """Yield every sample of each member dataset, in order."""
        for member in self.datasets:
            assert isinstance(
                member, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            yield from member

    def __len__(self):
        """Return the sum of the member datasets' lengths."""
        sizes = []
        for member in self.datasets:
            assert isinstance(
                member, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            sizes.append(len(member))  # type: ignore[arg-type]
        return sum(sizes)
class Subset(Dataset[T_co]):
    r"""
    View of a dataset restricted to a chosen sequence of indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        """Return the parent sample(s) selected by ``idx``.

        A ``list`` of positions is translated element-wise and handed to the
        parent dataset as a list; any other value is translated as one position.
        """
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        translated = [self.indices[position] for position in idx]
        return self.dataset[translated]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        """Return the parent samples for a batch of subset positions."""
        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        batched_fetch = getattr(self.dataset, "__getitems__", None)
        if not callable(batched_fetch):
            return [self.dataset[self.indices[position]] for position in indices]
        return batched_fetch([self.indices[position] for position in indices])  # type: ignore[attr-defined]

    def __len__(self):
        """Return the number of selected indices."""
        return len(self.indices)
def random_split(
    dataset: Dataset[T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Returns:
        One :class:`Subset` of ``dataset`` per entry of ``lengths``, in order.

    Raises:
        ValueError: if a fraction lies outside ``[0, 1]``, if a length is
            negative, or if the lengths do not sum to ``len(dataset)``.
    """
    dataset_size = len(dataset)  # type: ignore[arg-type]
    total = sum(lengths)

    # Inputs whose sum is close to (and not above) 1 are interpreted as fractions.
    if math.isclose(total, 1) and total <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            subset_lengths.append(int(math.floor(dataset_size * frac)))
        remainder = dataset_size - sum(subset_lengths)
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            subset_lengths[i % len(subset_lengths)] += 1
        lengths = subset_lengths
        total = sum(lengths)

    for i, length in enumerate(lengths):
        # A negative length can still satisfy the sum check below, but it would
        # turn into negative slice bounds and yield subsets of the wrong sizes.
        if length < 0:
            raise ValueError(f"Length of split at index {i} is negative")
        if length == 0:
            warnings.warn(
                f"Length of split at index {i} is 0. "
                "This might result in an empty dataset."
            )

    # Cannot verify that dataset is Sized
    if total != dataset_size:
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(total, generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    # Each running total is the exclusive end of that split's slice of the permutation.
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]