1480 lines
72 KiB
Python
1480 lines
72 KiB
Python
|
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.
|
||
|
|
||
|
To support these two classes, in `./_utils` we define many utility methods and
|
||
|
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
||
|
in `./_utils/worker.py`.
|
||
|
"""
|
||
|
|
||
|
import functools
|
||
|
import itertools
|
||
|
import logging
|
||
|
import os
|
||
|
import queue
|
||
|
import threading
|
||
|
import warnings
|
||
|
|
||
|
from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union
|
||
|
|
||
|
import multiprocessing as python_multiprocessing
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
import torch.multiprocessing as multiprocessing
|
||
|
import torch.utils.data.graph_settings
|
||
|
|
||
|
from torch._utils import ExceptionWrapper
|
||
|
|
||
|
from . import (
|
||
|
IterDataPipe,
|
||
|
MapDataPipe,
|
||
|
IterableDataset,
|
||
|
Sampler,
|
||
|
SequentialSampler,
|
||
|
RandomSampler,
|
||
|
BatchSampler,
|
||
|
Dataset,)
|
||
|
|
||
|
from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
|
||
|
|
||
|
from . import _utils
|
||
|
|
||
|
# Public names exported by this module.
__all__ = [
    "DataLoader",
    "get_worker_info",
    "default_collate",
    "default_convert",
]

# Covariant type variable used to parameterize `DataLoader` / `Dataset`.
T_co = TypeVar('T_co', covariant=True)
# Invariant type variable used in the `collate_fn` alias below.
T = TypeVar('T')
# Signature of `worker_init_fn`: takes an int (the worker id), returns nothing.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# These functions used to be defined in this file. However, they were moved to
# _utils/collate.py. Although it is rather hard to access them from user land
# (one has to explicitly and directly `import torch.utils.data.dataloader`), there
# probably is user code out there using them. This aliasing maintains BC in this
# aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert

# Re-exported alias of the worker-info accessor defined in `_utils.worker`.
get_worker_info = _utils.worker.get_worker_info

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class _DatasetKind:
|
||
|
Map = 0
|
||
|
Iterable = 1
|
||
|
|
||
|
@staticmethod
|
||
|
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
|
||
|
if kind == _DatasetKind.Map:
|
||
|
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
|
||
|
else:
|
||
|
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
|
||
|
|
||
|
|
||
|
class _InfiniteConstantSampler(Sampler):
|
||
|
r"""Analogous to ``itertools.repeat(None, None)``.
|
||
|
|
||
|
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
|
||
|
"""
|
||
|
|
||
|
def __iter__(self):
|
||
|
while True:
|
||
|
yield None
|
||
|
|
||
|
|
||
|
def _get_distributed_settings():
|
||
|
if dist.is_available() and dist.is_initialized():
|
||
|
return dist.get_world_size(), dist.get_rank()
|
||
|
else:
|
||
|
return 1, 0
|
||
|
|
||
|
|
||
|
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
|
||
|
global_worker_id = worker_id
|
||
|
info = torch.utils.data.get_worker_info()
|
||
|
assert info is not None
|
||
|
total_workers = info.num_workers
|
||
|
datapipe = info.dataset
|
||
|
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
|
||
|
# To distribute elements across distributed process evenly, we should shard data on distributed
|
||
|
# processes first then shard on worker processes
|
||
|
total_workers *= world_size
|
||
|
global_worker_id = global_worker_id * world_size + rank_id
|
||
|
# For BC, use default SHARDING_PRIORITIES
|
||
|
torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
|
||
|
if worker_init_fn is not None:
|
||
|
worker_init_fn(worker_id)
|
||
|
|
||
|
|
||
|
def _share_dist_seed(generator, pg):
|
||
|
_shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
|
||
|
if isinstance(pg, dist.ProcessGroup):
|
||
|
dist.broadcast(_shared_seed, src=0, group=pg)
|
||
|
return _shared_seed.item()
|
||
|
|
||
|
|
||
|
class DataLoader(Generic[T_co]):
    r"""
    Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
            ``None``, the default `multiprocessing context`_ of your operating system will
            be used. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            ``base_seed`` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise, if value of ``num_workers > 0`` default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.

    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
                 :ref:`data-loading-randomness` notes for random seed related questions.

    .. _multiprocessing context:
        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    """

    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Union[Sampler, Iterable]
    pin_memory_device: str
    prefetch_factor: Optional[int]
    _iterator : Optional['_BaseDataLoaderIter']
    # Flipped to True at the end of `__init__`; `__setattr__` uses it to
    # freeze the attributes that must not change after construction.
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: Optional[int] = None,
                 persistent_workers: bool = False,
                 pin_memory_device: str = ""):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor is not None:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the sentences apart.
            raise ValueError('prefetch_factor option could only be specified in multiprocessing. '
                             'Let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
        elif num_workers > 0 and prefetch_factor is None:
            prefetch_factor = 2
        elif prefetch_factor is not None and prefetch_factor < 0:
            raise ValueError('prefetch_factor option should be non-negative')

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.pin_memory_device = pin_memory_device
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Goes through the property setter below, which validates the value
        # against `self.num_workers` (hence it is assigned after num_workers).
        self.multiprocessing_context = multiprocessing_context

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
        if isinstance(self.dataset, IterDataPipe):
            self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
        elif isinstance(self.dataset, MapDataPipe):
            self.dataset = _MapDataPipeSerializationWrapper(self.dataset)

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if isinstance(dataset, IterDataPipe):
                if shuffle is not None:
                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
            # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
            elif shuffle not in {False, None}:
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}")

            if sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}")
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    f"batch_sampler option, but got batch_sampler={batch_sampler}")
        else:
            # Map-style: an unspecified (None) shuffle means "do not shuffle".
            shuffle = bool(shuffle)
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on `__setattr__` rejects changes to the frozen attributes.
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()

        torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        """Create a fresh single-process or multi-process iterator for this loader."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

    @property
    def multiprocessing_context(self):
        """The multiprocessing context stored by the setter (``None`` if unset)."""
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        # Accepts None, a start-method name, or a context object. A string is
        # resolved to a context; anything non-None requires num_workers > 0.
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if isinstance(multiprocessing_context, str):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            'multiprocessing_context option '
                            f'should specify a valid start method in {valid_start_methods!r}, but got '
                            f'multiprocessing_context={multiprocessing_context!r}')
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise TypeError('multiprocessing_context option should be a valid context '
                                    'object or a string specifying the start method, but got '
                                    f'multiprocessing_context={multiprocessing_context}')
            else:
                raise ValueError('multiprocessing_context can only be used with '
                                 'multi-process loading (num_workers > 0), but got '
                                 f'num_workers={self.num_workers}')

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        # Once `__init__` has finished, the attributes listed here are
        # read-only; every other attribute can still be assigned.
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
            raise ValueError(f'{attr} attribute should not be set after {self.__class__.__name__} is initialized')

        super().__setattr__(attr, val)

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid resetting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        """Whether automatic batching is on, i.e. a ``batch_sampler`` is set."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in
            # `self._IterableDataset_len_called`, and warn if the iterator ends up
            # yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
            if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil
                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)

    def check_worker_number_rationality(self):
        # This function checks whether the dataloader's worker number is rational based on
        # the current system's resources. The current rule is that if the number of workers this
        # DataLoader will create is bigger than the number of logical cpus that it is allowed to
        # use, then we will pop up a warning to let the user pay attention.
        #
        # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
        #     DataLoader process can use half of them which is 32, then the rational max number of
        #     worker that initiated from this process is 32.
        #     Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
        #     So the warning message is triggered to notify the user to lower the worker number if
        #     necessary.
        #
        #
        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
        #        available (available in most of Linux system, but not OSX and Windows).
        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
        #        it doesn't respect cpuset.
        #        We don't take threading into account since each worker process is single threaded
        #        at this time.
        #
        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
        #        other than `torch.set_num_threads` to 1 in the worker process, if the passing
        #        in functions use 3rd party modules that rely on those threading flags to determine
        #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to
        #        set those flags correctly.
        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
            # Builds the warning text; `num_worker_suggest` may be None when
            # no suggestion could be computed.
            if num_worker_suggest is not None:
                cpuset_note = "" if cpuset_checked else " (`cpuset` is not taken into account)"
                suggested_max_worker_msg = (
                    f"Our suggested max number of worker in current system is {num_worker_suggest}{cpuset_note}, "
                    "which is smaller than what this DataLoader is going to create.")
            else:
                suggested_max_worker_msg = (
                    "DataLoader is not able to compute a suggested max number of worker in current system.")

            return (
                f"This DataLoader will create {num_worker_created} worker processes in total. "
                f"{suggested_max_worker_msg} "
                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
                "lower the worker number to avoid potential slowness/freeze if necessary.")

        # Nothing to check when no worker process will be created.
        if not self.num_workers:
            return

        # try to compute a suggested max number of worker based on system's resource
        max_num_worker_suggest = None
        cpuset_checked = False
        if hasattr(os, 'sched_getaffinity'):
            try:
                max_num_worker_suggest = len(os.sched_getaffinity(0))
                cpuset_checked = True
            except Exception:
                # Best effort only: fall back to os.cpu_count() below.
                pass
        if max_num_worker_suggest is None:
            # os.cpu_count() could return Optional[int]
            # get cpu count first and check None in order to satisfy mypy check
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                max_num_worker_suggest = cpu_count

        # Warn when no suggestion could be computed at all, or when the
        # requested worker count exceeds the suggestion.
        if max_num_worker_suggest is None or self.num_workers > max_num_worker_suggest:
            warnings.warn(_create_warning_msg(
                max_num_worker_suggest,
                self.num_workers,
                cpuset_checked))
|
||
|
|
||
|
|
||
|
class _BaseDataLoaderIter:
|
||
|
def __init__(self, loader: DataLoader) -> None:
|
||
|
self._dataset = loader.dataset
|
||
|
self._shared_seed = None
|
||
|
self._pg = None
|
||
|
if isinstance(self._dataset, IterDataPipe):
|
||
|
if dist.is_available() and dist.is_initialized():
|
||
|
self._pg = dist.new_group(backend="gloo")
|
||
|
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
||
|
shared_rng = torch.Generator()
|
||
|
shared_rng.manual_seed(self._shared_seed)
|
||
|
self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
|
||
|
self._dataset_kind = loader._dataset_kind
|
||
|
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
||
|
self._auto_collation = loader._auto_collation
|
||
|
self._drop_last = loader.drop_last
|
||
|
self._index_sampler = loader._index_sampler
|
||
|
self._num_workers = loader.num_workers
|
||
|
ws, rank = _get_distributed_settings()
|
||
|
self._world_size = ws
|
||
|
self._rank = rank
|
||
|
# for other backends, pin_memory_device need to set. if not set
|
||
|
# default behaviour is CUDA device. if pin_memory_device is selected
|
||
|
# and pin_memory is not set, the default behaviour false.
|
||
|
if (len(loader.pin_memory_device) == 0):
|
||
|
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
|
||
|
self._pin_memory_device = None
|
||
|
else:
|
||
|
if not loader.pin_memory:
|
||
|
warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
|
||
|
"please set pin_memory to true, if you need to use the device pin memory")
|
||
|
warnings.warn(warn_msg)
|
||
|
|
||
|
self._pin_memory = loader.pin_memory
|
||
|
self._pin_memory_device = loader.pin_memory_device
|
||
|
self._timeout = loader.timeout
|
||
|
self._collate_fn = loader.collate_fn
|
||
|
self._sampler_iter = iter(self._index_sampler)
|
||
|
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
|
||
|
self._persistent_workers = loader.persistent_workers
|
||
|
self._num_yielded = 0
|
||
|
self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
|
||
|
|
||
|
def __iter__(self) -> '_BaseDataLoaderIter':
|
||
|
return self
|
||
|
|
||
|
def _reset(self, loader, first_iter=False):
|
||
|
self._sampler_iter = iter(self._index_sampler)
|
||
|
self._num_yielded = 0
|
||
|
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
||
|
if isinstance(self._dataset, IterDataPipe):
|
||
|
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
||
|
shared_rng = torch.Generator()
|
||
|
shared_rng.manual_seed(self._shared_seed)
|
||
|
self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
|
||
|
|
||
|
def _next_index(self):
|
||
|
return next(self._sampler_iter) # may raise StopIteration
|
||
|
|
||
|
def _next_data(self):
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def __next__(self) -> Any:
|
||
|
with torch.autograd.profiler.record_function(self._profile_name):
|
||
|
if self._sampler_iter is None:
|
||
|
# TODO(https://github.com/pytorch/pytorch/issues/76750)
|
||
|
self._reset() # type: ignore[call-arg]
|
||
|
data = self._next_data()
|
||
|
self._num_yielded += 1
|
||
|
if self._dataset_kind == _DatasetKind.Iterable and \
|
||
|
self._IterableDataset_len_called is not None and \
|
||
|
self._num_yielded > self._IterableDataset_len_called:
|
||
|
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
|
||
|
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
|
||
|
self._num_yielded)
|
||
|
if self._num_workers > 0:
|
||
|
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
|
||
|
"IterableDataset replica at each worker. Please see "
|
||
|
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
|
||
|
warnings.warn(warn_msg)
|
||
|
return data
|
||
|
|
||
|
def __len__(self) -> int:
|
||
|
return len(self._index_sampler)
|
||
|
|
||
|
def __getstate__(self):
|
||
|
# TODO: add limited pickling support for sharing an iterator
|
||
|
# across multiple threads for HOGWILD.
|
||
|
# Probably the best way to do this is by moving the sample pushing
|
||
|
# to a separate thread and then just sharing the data queue
|
||
|
# but signalling the end is tricky without a non-blocking API
|
||
|
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
|
||
|
|
||
|
|
||
|
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
|
||
|
def __init__(self, loader):
|
||
|
super().__init__(loader)
|
||
|
assert self._timeout == 0
|
||
|
assert self._num_workers == 0
|
||
|
|
||
|
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
||
|
# Taking care of distributed sharding
|
||
|
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
||
|
# For BC, use default SHARDING_PRIORITIES
|
||
|
torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
|
||
|
|
||
|
self._dataset_fetcher = _DatasetKind.create_fetcher(
|
||
|
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
|
||
|
|
||
|
def _next_data(self):
|
||
|
index = self._next_index() # may raise StopIteration
|
||
|
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
|
||
|
if self._pin_memory:
|
||
|
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
|
||
|
return data
|
||
|
|
||
|
|
||
|
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||
|
r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
|
||
|
|
||
|
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
|
||
|
#
|
||
|
# Preliminary:
|
||
|
#
|
||
|
# Our data model looks like this (queues are indicated with curly brackets):
|
||
|
#
|
||
|
# main process ||
|
||
|
# | ||
|
||
|
# {index_queue} ||
|
||
|
# | ||
|
||
|
# worker processes || DATA
|
||
|
# | ||
|
||
|
# {worker_result_queue} || FLOW
|
||
|
# | ||
|
||
|
# pin_memory_thread of main process || DIRECTION
|
||
|
# | ||
|
||
|
# {data_queue} ||
|
||
|
# | ||
|
||
|
# data output \/
|
||
|
#
|
||
|
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
|
||
|
# `pin_memory=False`.
|
||
|
#
|
||
|
#
|
||
|
# Terminating multiprocessing logic requires very careful design. In
|
||
|
# particular, we need to make sure that
|
||
|
#
|
||
|
# 1. The iterator gracefully exits the workers when its last reference is
|
||
|
# gone or it is depleted.
|
||
|
#
|
||
|
# In this case, the workers should be gracefully exited because the
|
||
|
# main process may still need to continue to run, and we want cleaning
|
||
|
# up code in the workers to be executed (e.g., releasing GPU memory).
|
||
|
# Naturally, we implement the shutdown logic in `__del__` of
|
||
|
# DataLoaderIterator.
|
||
|
#
|
||
|
# We delay the discussion on the logic in this case until later.
|
||
|
#
|
||
|
# 2. The iterator exits the workers when the loader process and/or worker
|
||
|
# processes exits normally or with error.
|
||
|
#
|
||
|
# We set all workers and `pin_memory_thread` to have `daemon=True`.
|
||
|
#
|
||
|
# You may ask, why can't we make the workers non-daemonic, and
|
||
|
# gracefully exit using the same logic as we have in `__del__` when the
|
||
|
# iterator gets deleted (see 1 above)?
|
||
|
#
|
||
|
# First of all, `__del__` is **not** guaranteed to be called when
|
||
|
# interpreter exits. Even if it is called, by the time it executes,
|
||
|
# many Python core library resources may already be freed, and even
|
||
|
# simple things like acquiring an internal lock of a queue may hang.
|
||
|
# Therefore, in this case, we actually need to prevent `__del__` from
|
||
|
# being executed, and rely on the automatic termination of daemonic
|
||
|
# children.
|
||
|
#
|
||
|
# Thus, we register an `atexit` hook that sets a global flag
|
||
|
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
|
||
|
# reverse order of registration, we are guaranteed that this flag is
|
||
|
# set before library resources we use are freed (which, at least in
|
||
|
# CPython, is done via an `atexit` handler defined in
|
||
|
# `multiprocessing/util.py`
|
||
|
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
|
||
|
# registered when an object requiring this mechanism is first
|
||
|
# created, e.g., `mp.Queue`
|
||
|
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
|
||
|
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
|
||
|
# )
|
||
|
#
|
||
|
# So in `__del__`, we check if `_utils.python_exit_status` is set or
|
||
|
# `None` (freed), and perform no-op if so.
|
||
|
#
|
||
|
# However, simply letting library clean-up codes run can also be bad,
|
||
|
# because such codes (i.e., `multiprocessing.util._exit_function()`)
|
||
|
# include join putting threads for `mp.Queue`, which can be blocking.
|
||
|
# Hence, the main process putting threads are called with
|
||
|
# `cancel_join_thread` at creation. See later section
|
||
|
# [ 3b. A process won't hang when putting into a queue; ]
|
||
|
# for more details.
|
||
|
#
|
||
|
# Here are two example cases where library clean-up codes can run
|
||
|
# before `__del__` is called:
|
||
|
#
|
||
|
# 1. If we hold onto a reference to the iterator, it more often
|
||
|
# than not tries to do `multiprocessing` library cleaning before
|
||
|
# clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
|
||
|
# and thus prevents our cleaning-up code to run first.
|
||
|
#
|
||
|
# 2. A similar issue arises when a `DataLoader` is used in a subprocess.
|
||
|
# When a process ends, it shuts all its daemonic children
|
||
|
# down with a SIGTERM (instead of joining them without a timeout).
|
||
|
# Similarly for threads, but by a different mechanism. This fact,
|
||
|
# together with a few implementation details of multiprocessing, forces
|
||
|
# us to make workers daemonic. All of our problems arise when a
|
||
|
# DataLoader is used in a subprocess, and are caused by multiprocessing
|
||
|
# code which looks more or less like this:
|
||
|
#
|
||
|
# try:
|
||
|
# your_function_using_a_dataloader()
|
||
|
# finally:
|
||
|
# multiprocessing.util._exit_function()
|
||
|
#
|
||
|
# The joining/termination mentioned above happens inside
|
||
|
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
|
||
|
# throws, the stack trace stored in the exception will prevent the
|
||
|
# frame which uses `DataLoaderIter` to be freed. If the frame has any
|
||
|
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
|
||
|
# its `__del__`, which starts the shutdown procedure, will not be
|
||
|
# called. That, in turn, means that workers aren't notified. Attempting
|
||
|
# to join in `_exit_function` will then result in a hang.
|
||
|
#
|
||
|
# For context, `_exit_function` is also registered as an `atexit` call.
|
||
|
# So it is unclear to me (@ssnl) why this is needed in a finally block.
|
||
|
# The code dates back to 2008 and there is no comment on the original
|
||
|
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
|
||
|
# the finally block and the `atexit` registration) that explains this.
|
||
|
#
|
||
|
#
|
||
|
# Finally, another choice is to just shutdown workers with logic in 1
|
||
|
# above whenever we see an error in `next`. This isn't ideal because
|
||
|
# a. It prevents users from using try-catch to resume data loading.
|
||
|
# b. It doesn't prevent hanging if users have references to the
|
||
|
# iterator.
|
||
|
#
|
||
|
# 3. All processes exit if any of them die unexpectedly by fatal signals.
|
||
|
#
|
||
|
# As shown above, the workers are set as daemonic children of the main
|
||
|
# process. However, automatic cleaning-up of such child processes only
|
||
|
# happens if the parent process exits gracefully (e.g., not via fatal
|
||
|
# signals like SIGKILL). So we must ensure that each process will exit
|
||
|
# even the process that should send/receive data to/from it were
|
||
|
# killed, i.e.,
|
||
|
#
|
||
|
# a. A process won't hang when getting from a queue.
|
||
|
#
|
||
|
# Even with carefully designed data dependencies (i.e., a `put()`
|
||
|
# always corresponding to a `get()`), hanging on `get()` can still
|
||
|
# happen when data in queue is corrupted (e.g., due to
|
||
|
# `cancel_join_thread` or unexpected exit).
|
||
|
#
|
||
|
# For child exit, we set a timeout whenever we try to get data
|
||
|
# from `data_queue`, and check the workers' status on each timeout
|
||
|
# and error.
|
||
|
# See `_DataLoaderiter._get_batch()` and
|
||
|
# `_DataLoaderiter._try_get_data()` for details.
|
||
|
#
|
||
|
# Additionally, for child exit on non-Windows platforms, we also
|
||
|
# register a SIGCHLD handler (which is not supported on Windows) on
|
||
|
# the main process, which checks if any of the workers fail in the
|
||
|
# (Python) handler. This is more efficient and faster in detecting
|
||
|
# worker failures, compared to only using the above mechanism.
|
||
|
# See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
|
||
|
#
|
||
|
# For `.get()` calls where the sender(s) is not the workers, we
|
||
|
# guard them with timeouts, and check the status of the sender
|
||
|
# when timeout happens:
|
||
|
# + in the workers, the `_utils.worker.ManagerWatchdog` class
|
||
|
# checks the status of the main process.
|
||
|
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
|
||
|
# check `pin_memory_thread` status periodically until `.get()`
|
||
|
# returns or see that `pin_memory_thread` died.
|
||
|
#
|
||
|
# b. A process won't hang when putting into a queue;
|
||
|
#
|
||
|
# We use `mp.Queue` which has a separate background thread to put
|
||
|
# objects from an unbounded buffer array. The background thread is
|
||
|
# daemonic and usually automatically joined when the process
|
||
|
# *exits*.
|
||
|
#
|
||
|
# In case that the receiver has ended abruptly while
|
||
|
# reading from the pipe, the join will hang forever. The usual
|
||
|
# solution for this in Python is calling `q.cancel_join_thread`,
|
||
|
# which prevents automatically joining it when finalizing
|
||
|
# (exiting).
|
||
|
#
|
||
|
# Nonetheless, `cancel_join_thread` must only be called when the
|
||
|
# queue is **not** going to be read from or write into by another
|
||
|
# process, because it may hold onto a lock or leave corrupted data
|
||
|
# in the queue, leading other readers/writers to hang.
|
||
|
#
|
||
|
# Hence,
|
||
|
# + For worker processes, we only do so (for their output
|
||
|
# queues, i.e., `worker_result_queue`) before exiting.
|
||
|
# + For `pin_memory_thread`, its output queue `data_queue` is a
|
||
|
# `queue.Queue` that does blocking `put` if the queue is full.
|
||
|
# So there is no above problem, but as a result, in
|
||
|
# `_pin_memory_loop`, we do need to wrap the `put` in a loop
|
||
|
# that breaks not only upon success, but also when the main
|
||
|
# process stops reading, i.e., is shutting down.
|
||
|
# + For loader process, we `cancel_join_thread()` for all
|
||
|
# `_index_queues` because the whole purpose of workers and
|
||
|
# `pin_memory_thread` is to serve the loader process. If
|
||
|
# loader process is already exiting, we don't really care if
|
||
|
# the queues are corrupted.
|
||
|
#
|
||
|
#
|
||
|
# Now let's get back to 1:
|
||
|
# how we gracefully exit the workers when the last reference to the
|
||
|
# iterator is gone.
|
||
|
#
|
||
|
# To achieve this, we implement the following logic along with the design
|
||
|
# choices mentioned above:
|
||
|
#
|
||
|
# `workers_done_event`:
|
||
|
# A `multiprocessing.Event` shared among the main process and all worker
|
||
|
# processes. This is used to signal the workers that the iterator is
|
||
|
# shutting down. After it is set, they will not send processed data to
|
||
|
# queues anymore, and only wait for the final `None` before exiting.
|
||
|
# `done_event` isn't strictly needed. I.e., we can just check for `None`
|
||
|
# from the input queue, but it allows us to skip wasting resources
|
||
|
# processing data if we are already shutting down.
|
||
|
#
|
||
|
# `pin_memory_thread_done_event`:
|
||
|
# A `threading.Event` for a similar purpose to that of
|
||
|
# `workers_done_event`, but is for the `pin_memory_thread`. The reason
|
||
|
# that separate events are needed is that `pin_memory_thread` reads from
|
||
|
# the output queue of the workers. But the workers, upon seeing that
|
||
|
# `workers_done_event` is set, only wants to see the final `None`, and is
|
||
|
# not required to flush all data in the output queue (e.g., it may call
|
||
|
# `cancel_join_thread` on that queue if its `IterableDataset` iterator
|
||
|
# happens to exhaust coincidentally, which is out of the control of the
|
||
|
# main process). Thus, since we will exit `pin_memory_thread` before the
|
||
|
# workers (see below), two separate events are used.
|
||
|
#
|
||
|
# NOTE: In short, the protocol is that the main process will set these
|
||
|
# `done_event`s and then send the corresponding processes/threads a `None`,
|
||
|
# and that they may exit at any time after receiving the `None`.
|
||
|
#
|
||
|
# NOTE: Using `None` as the final signal is valid, since normal data will
|
||
|
# always be a 2-tuple with the 1st element being the index of the data
|
||
|
# transferred (different from dataset index/key), and the 2nd being
|
||
|
# either the dataset key or the data sample (depending on which part
|
||
|
# of the data model the queue is at).
|
||
|
#
|
||
|
# [ worker processes ]
|
||
|
# While loader process is alive:
|
||
|
# Get from `index_queue`.
|
||
|
# If get anything else,
|
||
|
# Check `workers_done_event`.
|
||
|
# If set, continue to next iteration
|
||
|
# i.e., keep getting until see the `None`, then exit.
|
||
|
# Otherwise, process data:
|
||
|
# If is fetching from an `IterableDataset` and the iterator
|
||
|
# is exhausted, send an `_IterableDatasetStopIteration`
|
||
|
# object to signal iteration end. The main process, upon
|
||
|
# receiving such an object, will send `None` to this
|
||
|
# worker and not use the corresponding `index_queue`
|
||
|
# anymore.
|
||
|
# If timed out,
|
||
|
# No matter `workers_done_event` is set (still need to see `None`)
|
||
|
# or not, must continue to next iteration.
|
||
|
# (outside loop)
|
||
|
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
|
||
|
# `data_queue.cancel_join_thread()`. (Everything is ending here:
|
||
|
# main process won't read from it;
|
||
|
# other workers will also call
|
||
|
# `cancel_join_thread`.)
|
||
|
#
|
||
|
# [ pin_memory_thread ]
|
||
|
# # No need to check main thread. If this thread is alive, the main loader
|
||
|
# # thread must be alive, because this thread is set as daemonic.
|
||
|
# While `pin_memory_thread_done_event` is not set:
|
||
|
# Get from `worker_result_queue`.
|
||
|
# If timed out, continue to get in the next iteration.
|
||
|
# Otherwise, process data.
|
||
|
# While `pin_memory_thread_done_event` is not set:
|
||
|
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
|
||
|
# If timed out, continue to put in the next iteration.
|
||
|
# Otherwise, break, i.e., continuing to the out loop.
|
||
|
#
|
||
|
# NOTE: we don't check the status of the main thread because
|
||
|
# 1. if the process is killed by fatal signal, `pin_memory_thread`
|
||
|
# ends.
|
||
|
# 2. in other cases, either the cleaning-up in __del__ or the
|
||
|
# automatic exit of daemonic thread will take care of it.
|
||
|
# This won't busy-wait either because `.get(timeout)` does not
|
||
|
# busy-wait.
|
||
|
#
|
||
|
# [ main process ]
|
||
|
# In the DataLoader Iter's `__del__`
|
||
|
# b. Exit `pin_memory_thread`
|
||
|
# i. Set `pin_memory_thread_done_event`.
|
||
|
# ii Put `None` in `worker_result_queue`.
|
||
|
# iii. Join the `pin_memory_thread`.
|
||
|
# iv. `worker_result_queue.cancel_join_thread()`.
|
||
|
#
|
||
|
# c. Exit the workers.
|
||
|
# i. Set `workers_done_event`.
|
||
|
# ii. Put `None` in each worker's `index_queue`.
|
||
|
# iii. Join the workers.
|
||
|
# iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
|
||
|
#
|
||
|
# NOTE: (c) is better placed after (b) because it may leave corrupted
|
||
|
# data in `worker_result_queue`, which `pin_memory_thread`
|
||
|
# reads from, in which case the `pin_memory_thread` can only
|
||
|
# exit upon timing out, which is slow. Nonetheless, the same thing
|
||
|
# happens if a worker is killed by signal at unfortunate times,
|
||
|
# but in other cases, we are better off having a non-corrupted
|
||
|
# `worker_result_queue` for `pin_memory_thread`.
|
||
|
#
|
||
|
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
|
||
|
# can be omitted
|
||
|
#
|
||
|
# NB: `done_event`s isn't strictly needed. E.g., we can just check for
|
||
|
# `None` from `index_queue`, but it allows us to skip wasting resources
|
||
|
# processing indices already in `index_queue` if we are already shutting
|
||
|
# down.
|
||
|
|
||
|
def __init__(self, loader):
    """Start the worker processes (and the pin-memory thread, if enabled) for ``loader``.

    One daemonic worker process with its own index queue is started per
    ``self._num_workers``; all workers share a single result queue. When
    ``self._pin_memory`` is set, a daemonic thread running
    ``_utils.pin_memory._pin_memory_loop`` is started between the result
    queue and ``self._data_queue``. Finally the worker pids are registered
    with the signal-handling helpers and ``_reset`` is called for the
    first iteration.
    """
    super().__init__(loader)

    self._prefetch_factor = loader.prefetch_factor

    assert self._num_workers > 0
    assert self._prefetch_factor > 0

    # Fall back to `torch.multiprocessing` when the loader does not carry an
    # explicit multiprocessing context.
    mp_ctx = loader.multiprocessing_context
    if mp_ctx is None:
        mp_ctx = multiprocessing

    self._worker_init_fn = loader.worker_init_fn

    # Adds forward compatibilities so classic DataLoader can work with DataPipes:
    #   Additional worker init function will take care of sharding in MP and Distributed
    if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
        self._worker_init_fn = functools.partial(
            _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank)

    # No certainty which module multiprocessing_context is
    self._worker_result_queue = mp_ctx.Queue()  # type: ignore[var-annotated]
    self._worker_pids_set = False
    self._shutdown = False
    self._workers_done_event = mp_ctx.Event()

    self._index_queues = []
    self._workers = []
    for worker_id in range(self._num_workers):
        # No certainty which module multiprocessing_context is
        index_queue = mp_ctx.Queue()  # type: ignore[var-annotated]
        # Need to `cancel_join_thread` here!
        # See sections (2) and (3b) above.
        index_queue.cancel_join_thread()
        worker_args = (
            self._dataset_kind, self._dataset, index_queue,
            self._worker_result_queue, self._workers_done_event,
            self._auto_collation, self._collate_fn, self._drop_last,
            self._base_seed, self._worker_init_fn, worker_id, self._num_workers,
            self._persistent_workers, self._shared_seed,
        )
        worker = mp_ctx.Process(target=_utils.worker._worker_loop, args=worker_args)
        worker.daemon = True
        # NB: Process.start() actually take some time as it needs to
        #     start a process and pass the arguments over via a pipe.
        #     Therefore, we only add a worker to self._workers list after
        #     it started, so that we do not call .join() if program dies
        #     before it starts, and __del__ tries to join but will get:
        #     AssertionError: can only join a started process.
        worker.start()
        self._index_queues.append(index_queue)
        self._workers.append(worker)

    if self._pin_memory:
        self._pin_memory_thread_done_event = threading.Event()

        # Queue is not type-annotated
        self._data_queue = queue.Queue()  # type: ignore[var-annotated]
        if self._pin_memory_device == "xpu":
            current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
        elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
            custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
            current_device = custom_device_mod.current_device()
        else:
            current_device = torch.cuda.current_device()  # choose cuda for default
        pin_thread_args = (
            self._worker_result_queue, self._data_queue,
            current_device,
            self._pin_memory_thread_done_event, self._pin_memory_device,
        )
        pin_memory_thread = threading.Thread(
            target=_utils.pin_memory._pin_memory_loop,
            args=pin_thread_args,
            daemon=True,
        )
        pin_memory_thread.start()
        # Similar to workers (see comment above), we only register
        # pin_memory_thread once it is started.
        self._pin_memory_thread = pin_memory_thread
    else:
        self._data_queue = self._worker_result_queue  # type: ignore[assignment]

    # In some rare cases, persistent workers (daemonic processes)
    # would be terminated before `__del__` of iterator is invoked
    # when main process exits
    # It would cause failure when pin_memory_thread tries to read
    # corrupted data from worker_result_queue
    # atexit is used to shutdown thread and child processes in the
    # right sequence before main process exits
    if self._persistent_workers and self._pin_memory:
        import atexit
        for started_worker in self._workers:
            atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, started_worker)

    # .pid can be None only before process is spawned (not the case, so ignore)
    worker_pids = tuple(started_worker.pid for started_worker in self._workers)
    _utils.signal_handling._set_worker_pids(id(self), worker_pids)  # type: ignore[misc]
    _utils.signal_handling._set_SIGCHLD_handler()
    self._worker_pids_set = True
    self._reset(loader, first_iter=True)
|
||
|
|
||
|
def _reset(self, loader, first_iter=False):
    """Reinitialize the per-epoch bookkeeping and prime the prefetch pipeline.

    When ``first_iter`` is false, a ``_ResumeIteration`` message is put on
    every index queue and this method then consumes from ``_get_data()``
    until one ``_ResumeIteration`` reply per worker has been received.
    """
    super()._reset(loader, first_iter)
    # idx of the next task to be sent to workers
    self._send_idx = 0
    # idx of the next task to be returned in __next__
    self._rcvd_idx = 0
    # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
    # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
    #                  \ (worker_id, data)   if data is already fetched (out-of-order)
    self._task_info = {}
    # always equal to count(v for v in task_info.values() if len(v) == 1)
    self._tasks_outstanding = 0
    # A list of booleans representing whether each worker still has work to
    # do, i.e., not having exhausted its iterable dataset object. It always
    # contains all `True`s if not using an iterable-style dataset
    # (i.e., if kind != Iterable).
    # Note that this indicates that a worker still has work to do *for this epoch*.
    # It does not mean that a worker is dead. In case of `_persistent_workers`,
    # the worker will be reset to available in the next epoch.
    self._workers_status = [True] * self._num_workers
    # Reset the worker queue cycle so it resumes next epoch at worker 0
    self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
    # We resume the prefetching in case it was enabled
    if not first_iter:
        for worker_id in range(self._num_workers):
            resume_msg = _utils.worker._ResumeIteration(self._shared_seed)
            self._index_queues[worker_id].put(resume_msg)
        pending_acks = self._num_workers
        while pending_acks > 0:
            return_idx, return_data = self._get_data()
            if isinstance(return_idx, _utils.worker._ResumeIteration):
                assert return_data is None
                pending_acks -= 1
    # prime the prefetch loop
    prefetch_budget = self._prefetch_factor * self._num_workers
    for _ in range(prefetch_budget):
        self._try_put_index()
|
||
|
|
||
|
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
    """Try once to fetch an item from ``self._data_queue`` within ``timeout``.

    This can also be used as inner loop of fetching without timeout, with
    the sender status as the loop condition.

    Returns:
        A 2-tuple ``(success, data)``: ``(True, data)`` when an item was
        fetched, ``(False, None)`` when the queue stayed empty.

    Raises:
        RuntimeError: if any worker that is still marked as active is no
            longer alive, or if the process appears to have hit its open
            file descriptor limit (``EMFILE``). The worker check can also
            be triggered by the SIGCHLD handler in
            `_utils/signal_handling.py` (only for non-Windows platforms).
        Exception: any other error raised by the queue is re-raised as is.
    """
    try:
        data = self._data_queue.get(timeout=timeout)
        return (True, data)
    except Exception as e:
        # At timeout and error, we manually check whether any worker has
        # failed. Note that this is the only mechanism for Windows to detect
        # worker failures.
        failed_workers = []
        for worker_id, w in enumerate(self._workers):
            if self._workers_status[worker_id] and not w.is_alive():
                failed_workers.append(w)
                self._mark_worker_as_unavailable(worker_id)
        if failed_workers:
            pids_str = ', '.join(str(w.pid) for w in failed_workers)
            raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
        if isinstance(e, queue.Empty):
            return (False, None)
        import contextlib
        import errno
        import tempfile
        try:
            # Raise an exception if we are this close to the FDs limit.
            # Apparently, trying to open only one file is not a sufficient
            # test.
            # See NOTE [ DataLoader on Linux and open files limit ]
            fds_limit_margin = 10
            # Hold all probe files open at the same time, then close every one
            # of them on exit. Otherwise the re-raised exception's traceback
            # would keep this frame, and with it the descriptors, alive.
            with contextlib.ExitStack() as probe_files:
                for _ in range(fds_limit_margin):
                    probe_files.enter_context(tempfile.NamedTemporaryFile())
        except OSError as os_err:
            if os_err.errno == errno.EMFILE:
                raise RuntimeError(
                    "Too many open files. Communication with the"
                    " workers is no longer possible. Please increase the"
                    " limit using `ulimit -n` in the shell or change the"
                    " sharing strategy by calling"
                    " `torch.multiprocessing.set_sharing_strategy('file_system')`"
                    " at the beginning of your code") from None
        # Not a worker failure, not an empty queue, not an FD-limit problem:
        # propagate the original queue error unchanged.
        raise
|
||
|
|
||
|
# NOTE [ DataLoader on Linux and open files limit ]
|
||
|
#
|
||
|
# On Linux when DataLoader is used with multiprocessing we pass the data between
|
||
|
# the root process and the workers through SHM files. We remove those files from
|
||
|
# the filesystem as soon as they are created and keep them alive by
|
||
|
# passing around their file descriptors through AF_UNIX sockets. (See
|
||
|
# docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes' in
|
||
|
# the wiki (https://github.com/pytorch/pytorch/wiki).)
|
||
|
#
|
||
|
# This sometimes leads us to exceeding the open files limit. When that happens,
|
||
|
# and the offending file descriptor is coming over a socket, the `socket` Python
|
||
|
# package silently strips the file descriptor from the message, setting only the
|
||
|
# `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
|
||
|
# it _indicates that some control data were discarded due to lack of space in
|
||
|
# the buffer for ancillary data_). This might reflect the C implementation of
|
||
|
# AF_UNIX sockets.
|
||
|
#
|
||
|
# This behaviour can be reproduced with the script and instructions at the
|
||
|
# bottom of this note.
|
||
|
#
|
||
|
# When that happens, the standard Python `multiprocessing` (and not
|
||
|
# `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
|
||
|
#
|
||
|
# Sometimes, instead of the FD being stripped, you may get an `OSError:
|
||
|
# Too many open files`, both in the script below and in DataLoader. However,
|
||
|
# this is rare and seems to be nondeterministic.
|
||
|
#
|
||
|
#
|
||
|
# #!/usr/bin/env python3
|
||
|
# import sys
|
||
|
# import socket
|
||
|
# import os
|
||
|
# import array
|
||
|
# import shutil
|
||
|
# import socket
|
||
|
#
|
||
|
#
|
||
|
# if len(sys.argv) != 4:
|
||
|
# print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
|
||
|
# sys.exit(1)
|
||
|
#
|
||
|
# if __name__ == '__main__':
|
||
|
# dirname = sys.argv[1]
|
||
|
# sock_path = dirname + "/sock"
|
||
|
# iterations = int(sys.argv[2])
|
||
|
# def dummy_path(i):
|
||
|
# return dirname + "/" + str(i) + ".dummy"
|
||
|
#
|
||
|
#
|
||
|
# if sys.argv[3] == 'send':
|
||
|
# while not os.path.exists(sock_path):
|
||
|
# pass
|
||
|
# client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||
|
# client.connect(sock_path)
|
||
|
# for i in range(iterations):
|
||
|
# fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
|
||
|
# ancdata = array.array('i', [fd])
|
||
|
# msg = bytes([i % 256])
|
||
|
# print("Sending fd ", fd, " (iteration #", i, ")")
|
||
|
# client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
|
||
|
#
|
||
|
#
|
||
|
# else:
|
||
|
# assert sys.argv[3] == 'recv'
|
||
|
#
|
||
|
# if os.path.exists(dirname):
|
||
|
# raise Exception("Directory exists")
|
||
|
#
|
||
|
# os.mkdir(dirname)
|
||
|
#
|
||
|
# print("Opening socket...")
|
||
|
# server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||
|
# server.bind(sock_path)
|
||
|
#
|
||
|
# print("Listening...")
|
||
|
# for i in range(iterations):
|
||
|
# a = array.array('i')
|
||
|
# msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
|
||
|
# assert(len(ancdata) == 1)
|
||
|
# cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
||
|
# a.frombytes(cmsg_data)
|
||
|
# print("Received fd ", a[0], " (iteration #", i, ")")
|
||
|
#
|
||
|
# shutil.rmtree(dirname)
|
||
|
#
|
||
|
# Steps to reproduce:
|
||
|
#
|
||
|
# 1. Run two shells and set lower file descriptor limit in the receiving one:
|
||
|
# (shell1) ulimit -n 1020
|
||
|
# (shell2) ulimit -n 1022
|
||
|
#
|
||
|
# 2. Run the script above with the `recv` option in the first shell
|
||
|
# (shell1) ./test_socket.py sock_tmp 1017 recv
|
||
|
#
|
||
|
# 3. Run the script with the `send` option in the second shell:
|
||
|
# (shell2) ./test_socket.py sock_tmp 1017 send
|
||
|
|
||
|
def _get_data(self):
    """Fetch one item from ``self._data_queue``, blocking until one is available.

    We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
    which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
    in a loop. This is the only mechanism to detect worker failures for
    Windows. For other platforms, a SIGCHLD handler is also used for
    worker failure detection.

    If `pin_memory=True`, we also need check if `pin_memory_thread` had
    died at timeouts.

    Raises:
        RuntimeError: if a positive ``self._timeout`` elapses without data,
            or if the pin-memory thread is found not to be alive.
    """
    if self._timeout > 0:
        # A user-specified timeout means exactly one bounded attempt.
        success, data = self._try_get_data(self._timeout)
        if not success:
            raise RuntimeError(f'DataLoader timed out after {self._timeout} seconds')
        return data

    if self._pin_memory:
        # In this case, `self._data_queue` is a `queue.Queue`. But we don't
        # need to call `.task_done()` because we don't use `.join()`.
        while self._pin_memory_thread.is_alive():
            success, data = self._try_get_data()
            if success:
                return data
        # Loop condition became false, i.e., pin_memory_thread died.
        raise RuntimeError('Pin memory thread exited unexpectedly')

    # No timeout and no pin-memory thread: poll until an item arrives.
    while True:
        success, data = self._try_get_data()
        if success:
            return data
|
||
|
|
||
|
def _next_data(self):
    """Return the next batch in task order, buffering results that arrive early.

    Raises:
        StopIteration: when no task in ``[self._rcvd_idx, self._send_idx)``
            can still produce data. Workers are shut down first unless
            ``self._persistent_workers`` is set.
    """
    while True:
        # If the worker responsible for `self._rcvd_idx` has already ended
        # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
        # we try to advance `self._rcvd_idx` to find the next valid index.
        #
        # This part needs to run in the loop because both the `self._get_data()`
        # call and `_IterableDatasetStopIteration` check below can mark
        # extra worker(s) as dead.
        while self._rcvd_idx < self._send_idx:
            task = self._task_info[self._rcvd_idx]
            owner = task[0]
            # has data or is still active
            if len(task) == 2 or self._workers_status[owner]:
                break
            del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration

        # Now `self._rcvd_idx` is the batch index we want to fetch

        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            _, ready = self._task_info.pop(self._rcvd_idx)
            return self._process_data(ready)

        assert not self._shutdown and self._tasks_outstanding > 0
        idx, data = self._get_data()
        self._tasks_outstanding -= 1
        # Check for _IterableDatasetStopIteration
        if (self._dataset_kind == _DatasetKind.Iterable
                and isinstance(data, _utils.worker._IterableDatasetStopIteration)):
            if self._persistent_workers:
                self._workers_status[data.worker_id] = False
            else:
                self._mark_worker_as_unavailable(data.worker_id)
            self._try_put_index()
            continue

        if idx != self._rcvd_idx:
            # store out-of-order samples
            self._task_info[idx] += (data,)
            continue

        del self._task_info[idx]
        return self._process_data(data)
|
||
|
|
||
|
def _try_put_index(self):
    # Send the next batch of indices to an active worker, if there is both
    # an index left and a worker able to take it.
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

    try:
        index = self._next_index()
    except StopIteration:
        # No more indices to hand out.
        return

    # Advance the round-robin cycle at most `_num_workers` steps, stopping
    # at the first worker that is still active.
    target = None
    for _ in range(self._num_workers):
        candidate = next(self._worker_queue_idx_cycle)
        if self._workers_status[candidate]:
            target = candidate
            break
    if target is None:
        # No active worker was found, so nothing is dispatched.
        return

    send_idx = self._send_idx
    self._index_queues[target].put((send_idx, index))
    # A 1-tuple records the worker id for a task that has no data yet.
    self._task_info[send_idx] = (target,)
    self._tasks_outstanding += 1
    self._send_idx += 1
|
||
|
|
||
|
def _process_data(self, data):
    # Account for one received batch, try to schedule another task, then
    # either hand the batch back or re-raise the error it carries.
    self._rcvd_idx += 1
    self._try_put_index()
    if not isinstance(data, ExceptionWrapper):
        return data
    # `data` wraps an exception captured elsewhere; surface it to the caller.
    data.reraise()
    return data
|
||
|
|
||
|
def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
    # Record that a worker has finished its work (e.g., because it exhausted
    # an `IterableDataset`). Outside of shutdown, use this only while this
    # `_MultiProcessingDataLoaderIter` is going to keep running.

    # The worker must still be active, unless we are shutting down with
    # persistent workers.
    assert self._workers_status[worker_id] or (self._persistent_workers and shutdown)

    # Signal termination to this specific worker: `None` on its index queue
    # indicates that the current process will put no more data there.
    self._index_queues[worker_id].put(None)

    # The worker is deliberately not joined here, and its pid is not removed
    # from the C side struct, because (1) joining may be slow, and (2) an
    # un-joined worker may still raise an error, which we prefer to capture
    # rather than ignore even after the worker has finished its job.
    # Joining is deferred to `_shutdown_workers`, which is called when all
    # workers finish their jobs (e.g., `IterableDataset` replicas) or when
    # this iterator is garbage collected.

    self._workers_status[worker_id] = False

    # The done event must be set exactly when this call is part of shutdown.
    assert self._workers_done_event.is_set() == shutdown
|
||
|
|
||
|
def _shutdown_workers(self):
    # Called when shutting down this `_MultiProcessingDataLoaderIter`.
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    # the logic of this function.
    if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
        # See (2) of the note. If Python is shutting down, do no-op.
        return
    # Normal exit when last reference is gone / iterator is depleted.
    # See (1) and the second half of the note.
    if not self._shutdown:
        # Set the flag first so that a repeated call skips this whole branch.
        self._shutdown = True
        try:
            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, '_pin_memory_thread'):
                # Use hasattr in case error happens before we set the attribute.
                self._pin_memory_thread_done_event.set()
                # Send something to pin_memory_thread in case it is waiting
                # so that it can wake up and check `pin_memory_thread_done_event`
                self._worker_result_queue.put((None, None))
                self._pin_memory_thread.join()
                self._worker_result_queue.cancel_join_thread()
                self._worker_result_queue.close()

            # Exit workers now.
            self._workers_done_event.set()
            for worker_id in range(len(self._workers)):
                # Get number of workers from `len(self._workers)` instead of
                # `self._num_workers` in case we error before starting all
                # workers.
                # If we are using workers_status with persistent_workers
                # we have to shut it down because the worker is paused
                if self._persistent_workers or self._workers_status[worker_id]:
                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
            for w in self._workers:
                # We should be able to join here, but in case anything went
                # wrong, we set a timeout and if the workers fail to join,
                # they are killed in the `finally` block.
                w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
            for q in self._index_queues:
                q.cancel_join_thread()
                q.close()
        finally:
            # Even though all this function does is putting into queues that
            # we have called `cancel_join_thread` on, weird things can
            # happen when a worker is killed by a signal, e.g., hanging in
            # `Event.set()`. So we need to guard this with SIGCHLD handler,
            # and remove pids from the C side data structure only at the
            # end.
            #
            # FIXME: Unfortunately, for Windows, we are missing a worker
            #        error detection mechanism here in this function, as it
            #        doesn't provide a SIGCHLD handler.
            if self._worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
            for w in self._workers:
                if w.is_alive():
                    # Existing mechanisms try to make the workers exit
                    # peacefully, but in case that we unfortunately reach
                    # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                    # we kill the worker.
                    w.terminate()
|
||
|
|
||
|
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
|
||
|
@staticmethod
def _clean_up_worker(w):
    # Wait a bounded amount of time for worker `w` to exit, then terminate it
    # if it is still alive.
    try:
        w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
    finally:
        # Runs whether or not `join` raised.
        still_running = w.is_alive()
        if still_running:
            w.terminate()
|
||
|
|
||
|
def __del__(self):
    # Finalizer: delegate all cleanup to `_shutdown_workers`.
    self._shutdown_workers()
|