306 lines
12 KiB
Python
306 lines
12 KiB
Python
|
import torch
|
||
|
from torch import Tensor
|
||
|
|
||
|
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
|
||
|
|
||
|
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]
|
||
|
|
||
|
# Covariant element type produced by a sampler: `int` for index samplers,
# `List[int]` for batch samplers.
T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements.
    Subclasses may also provide a :meth:`__len__` method that returns the length
    of the returned iterators; this base class deliberately defines none.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.
            Passing anything other than ``None`` emits a warning.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            import warnings

            # The trailing space keeps the two sentences apart once the
            # adjacent string literals are concatenated.
            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
                          "You may still have custom implementation that utilizes it.")

    def __iter__(self) -> Iterator[T_co]:
        """Iterate over sampled elements; subclasses must override this."""
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError()`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
|
||
|
|
||
|
|
||
|
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0, 1, ..., len(data_source) - 1`` in ascending order.

    The order is the same on every pass.

    Args:
        data_source (Dataset): dataset to sample from; only its length is used
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        # The length is read when iteration starts, so a resized data source
        # is picked up by the next pass.
        total = len(self.data_source)
        return iter(range(total))

    def __len__(self) -> int:
        total = len(self.data_source)
        return total
|
||
|
|
||
|
|
||
|
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    :attr:`num_samples` may be given in either mode. With replacement, indices are
    drawn uniformly at random. Without replacement, full random permutations of the
    dataset are emitted back to back, followed by a prefix of one more permutation,
    until :attr:`num_samples` indices have been produced.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the effective ``num_samples`` is not a positive ``int``
            (``bool`` values are rejected, matching the other samplers in this module).
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        # Resolve the property once: it may call len(data_source).
        num_samples = self.num_samples
        # `bool` is a subclass of `int`; reject it explicitly so that
        # `num_samples=True` is not silently treated as 1.
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")

    @property
    def num_samples(self) -> int:
        """Number of indices produced per pass (``len(data_source)`` when unset)."""
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        # Snapshot both sizes together so one pass is internally consistent
        # even if the data source is resized while the generator is suspended.
        n = len(self.data_source)
        num_samples = self.num_samples

        if self.generator is None:
            # Derive a private generator seeded from the global RNG so that
            # drawing samples consumes exactly one value from the global stream.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 so indices are produced on demand rather
            # than materializing all `num_samples` at once.
            for _ in range(num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # Whole permutations first, then the leading part of one more.
            for _ in range(num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
|
||
|
|
||
|
|
||
|
class SubsetRandomSampler(Sampler[int]):
    r"""Yields every entry of a given index sequence exactly once, in random order.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.generator = generator
        self.indices = indices

    def __iter__(self) -> Iterator[int]:
        # Shuffle positions into `indices` rather than the entries themselves;
        # each position (a 0-dim tensor) is then used to look up its entry.
        order = torch.randperm(len(self.indices), generator=self.generator)
        yield from (self.indices[position] for position in order)

    def __len__(self) -> int:
        count = len(self.indices)
        return count
|
||
|
|
||
|
|
||
|
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Raises:
        ValueError: if ``num_samples`` is not a positive non-bool ``int``, if
            ``replacement`` is not a ``bool``, or if ``weights`` is not one-dimensional.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # `bool` is an `int` subclass, so it has to be excluded by hand.
        count_is_valid = (
            isinstance(num_samples, int)
            and not isinstance(num_samples, bool)
            and num_samples > 0
        )
        if not count_is_valid:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        # Weights are stored as a double-precision tensor.
        as_double = torch.as_tensor(weights, dtype=torch.double)
        if as_double.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(as_double.shape)}")

        self.weights = as_double
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # All `num_samples` indices are drawn in one multinomial call and then
        # handed out one by one as plain Python ints.
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        yield from drawn.tolist()

    def __len__(self) -> int:
        return self.num_samples
|
||
|
|
||
|
|
||
|
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not a positive non-bool ``int`` or
            ``drop_last`` is not a ``bool``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # `sampler` is intentionally not type-checked: collections.abc.Iterable
        # does not recognise objects that are iterable only through
        # `__getitem__`, so an `isinstance` test would reject valid inputs.
        size_is_valid = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_is_valid:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        size = self.batch_size
        if self.drop_last:
            source = iter(self.sampler)
            # Zipping `size` references to one iterator pulls `size` items per
            # step and stops as soon as a step cannot be completed, which
            # discards a trailing partial batch.
            for group in zip(*([source] * size)):
                yield list(group)
            return

        pending: List[int] = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == size:
                yield pending
                # Start a fresh list so batches already handed out are not mutated.
                pending = []
        if pending:
            yield pending

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)  # type: ignore[arg-type]
        if self.drop_last or leftover == 0:
            return full_batches
        return full_batches + 1
|