150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
|
import io
|
||
|
import pickle
|
||
|
import warnings
|
||
|
|
||
|
from collections.abc import Collection
|
||
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||
|
|
||
|
from torch.utils.data import IterDataPipe, MapDataPipe
|
||
|
from torch.utils._import_utils import dill_available
|
||
|
|
||
|
|
||
|
# Only the two traversal entry points are public; names prefixed with ``_`` are
# implementation details of this module.
__all__ = [
    "traverse",
    "traverse_dps",
]

# Either flavour of DataPipe accepted by the traversal functions.
DataPipe = Union[IterDataPipe, MapDataPipe]

# Recursive graph type: ``id(datapipe) -> (datapipe, sub-graph of the DataPipes it is connected to)``.
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
|
||
|
|
||
|
|
||
|
def _stub_unpickler():
    """Placeholder reconstructor that stands in for a connected DataPipe.

    ``_list_connected_datapipes`` returns ``(_stub_unpickler, ())`` from its
    reduce hook, so pickling records a by-name reference to this module-level
    function instead of the DataPipe's real state. Within this module the
    pickled stream is discarded, so the function only needs to be importable
    by name; it is not expected to actually run.
    """
    stub_marker = "STUB"
    return stub_marker
|
||
|
|
||
|
|
||
|
# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    """Return the DataPipes directly connected to ``scan_obj``.

    ``scan_obj`` is pickled into a throw-away buffer while temporary hooks are
    installed on ``IterDataPipe`` and ``MapDataPipe``. Every *other* DataPipe met
    during serialization is recorded and replaced by ``_stub_unpickler``. Its own
    state is therefore not serialized, and only first-level connections are
    collected.

    Args:
        scan_obj: the DataPipe whose connections are listed.
        only_datapipe: if ``True``, a getstate hook is also installed that reduces
            each DataPipe's state to the values that are DataPipes or
            ``collections.abc.Collection`` instances. Other attributes (e.g. plain
            functions) are then never serialized.
        cache: ids of DataPipes that must not be reported as connections; the hook
            declines them, so they are pickled with their own state. Ids of newly
            captured DataPipes are added to this set in place.

    Returns:
        The captured DataPipes, in the order the pickler encountered them.

    Raises:
        pickle.PickleError, AttributeError, TypeError: re-raised from the standard
            pickler when it fails and ``dill`` is not available.
    """
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if dill_available():
        from dill import Pickler as dill_Pickler

        d = dill_Pickler(f)
    else:
        d = None

    captured_connections: List[DataPipe] = []
    # Snapshot of the caller's ids so a failed first pickle attempt can be rolled
    # back before retrying with dill.
    initial_cache = set(cache)

    def getstate_hook(ori_state):
        # Keep only values that may be, or contain, another DataPipe.
        keep_types = (IterDataPipe, MapDataPipe, Collection)
        if isinstance(ori_state, dict):
            return {k: v for k, v in ori_state.items() if isinstance(v, keep_types)}
        if isinstance(ori_state, (tuple, list)):
            return [v for v in ori_state if isinstance(v, keep_types)]
        if isinstance(ori_state, keep_types):
            return ori_state
        return None

    def reduce_hook(obj):
        # Compare by identity: ``==`` could dispatch to a user-defined ``__eq__`` and
        # mistake a different DataPipe for ``scan_obj`` (or return a non-bool).
        if obj is scan_obj or id(obj) in cache:
            # Decline the hook so ``obj`` is pickled with its own state; presumably this
            # is how the DataPipe ``__reduce_ex__`` treats NotImplementedError -- confirm there.
            raise NotImplementedError
        captured_connections.append(obj)
        # Adding id to remove duplicate DataPipe serialized at the same level
        cache.add(id(obj))
        # A reducer with no state arguments: ``obj``'s own attributes are not walked.
        return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe], ...] = (IterDataPipe, MapDataPipe)

    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if d is None:
                raise
            # The standard pickler can fail part-way through ``scan_obj``'s state, after
            # some connections were already captured. Discard that partial result.
            # Otherwise those ids stay in ``cache``, dill would pickle those DataPipes
            # with their full state, and their own children would be reported as direct
            # connections of ``scan_obj``.
            captured_connections.clear()
            cache.intersection_update(initial_cache)
            d.dump(scan_obj)
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if dill_available():
            from dill import extend as dill_extend

            dill_extend(False)  # Undo change to dispatch table
    return captured_connections
|
||
|
|
||
|
|
||
|
def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    Only attributes of each DataPipe that are themselves DataPipes, or Python
    collection objects such as ``list``, ``tuple``, ``set`` and ``dict``, are
    inspected.

    Args:
        datapipe: the end DataPipe of the graph

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of DataPipe instance and the sub-graph
    """
    # A fresh, empty visited-set for every top-level traversal.
    return _traverse_helper(datapipe, only_datapipe=True, cache=set())
|
||
|
|
||
|
|
||
|
def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    [Deprecated]
    When ``only_datapipe`` is specified as ``True``, it would only look into the
    attribute from each DataPipe that is either a DataPipe or a Python collection object
    such as ``list``, ``tuple``, ``set`` and ``dict``.

    Note:
        This function is deprecated. Please use `traverse_dps` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
            ``None`` is treated as ``False``.
            This argument is deprecated and will be removed after the next release.

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of DataPipe instance and the sub-graph

    Warns:
        FutureWarning: on every call. Unless ``only_datapipe`` is truthy, the message also
            announces the upcoming switch to ``only_datapipe=True`` behavior.
    """
    msg = (
        "`traverse` function is deprecated and will be removed after 1.13. "
        "Please use `traverse_dps` instead."
    )
    if not only_datapipe:
        msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
    # stacklevel=2 attributes the warning to the caller's line rather than to this module.
    warnings.warn(msg, FutureWarning, stacklevel=2)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)
|
||
|
|
||
|
|
||
|
# ``cache`` holds the ids already on the current path, which prevents infinite recursion on cyclic DataPipes.
def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
    """Recursively build the DataPipe graph rooted at ``datapipe``.

    Args:
        datapipe: the DataPipe to expand.
        only_datapipe: forwarded to ``_list_connected_datapipes``.
        cache: ids of DataPipes already on the current path. ``id(datapipe)`` is
            added to this set in place; every recursive call receives a copy.

    Returns:
        ``{id(datapipe): (datapipe, sub_graph)}``, or ``{}`` when ``datapipe`` is
        already on the current path (i.e. a cycle was reached).

    Raises:
        RuntimeError: if ``datapipe`` is neither an ``IterDataPipe`` nor a ``MapDataPipe``.
    """
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError(f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found")

    node_id = id(datapipe)
    if node_id in cache:
        # Already visited on this path: stop to break the cycle.
        return {}
    cache.add(node_id)

    # Copies of ``cache`` are handed down so that recursion is only blocked along a
    # single path. The same DataPipe may legitimately appear on several paths of the
    # graph, and one path must not pollute another's visited-set.
    neighbours = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
    sub_graph: DataPipeGraph = {}
    for neighbour in neighbours:
        sub_graph.update(_traverse_helper(neighbour, only_datapipe, cache.copy()))
    return {node_id: (datapipe, sub_graph)}
|