282 lines
11 KiB
Python
282 lines
11 KiB
Python
import collections
|
|
import copyreg
|
|
import io
|
|
import pickle
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
from enum import Enum
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch._C._distributed_rpc import _get_current_rpc_agent
|
|
|
|
__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]
|
|
|
|
# Thread local tensor tables to store tensors while pickling torch.Tensor
|
|
# objects
|
|
_thread_local_tensor_tables = threading.local()
|
|
_pickler = pickle.Pickler
|
|
_unpickler = pickle.Unpickler
|
|
|
|
|
|
class RPCExecMode(Enum):
|
|
SYNC = "sync"
|
|
ASYNC = "async"
|
|
ASYNC_JIT = "async_jit"
|
|
REMOTE = "remote"
|
|
|
|
|
|
class _InternalRPCPickler:
|
|
r"""
|
|
This class provides serialize() and deserialize() interfaces to serialize
|
|
data to be "binary string + tensor table" format
|
|
So for RPC python UDF function and args, non tensor data will be serialized
|
|
into regular binary string, tensor data will be put into thread local tensor
|
|
tables, this serialization format is consistent with builtin operator and args
|
|
using JIT pickler. This format will make tensor handling in C++ much easier,
|
|
e.g. attach tensor to distributed autograd graph in C++
|
|
"""
|
|
|
|
def __init__(self):
|
|
# Ignore type error because dispatch_table is defined in third-party package
|
|
self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
|
|
self._dispatch_table[torch.Tensor] = self._tensor_reducer
|
|
# Used for registering customized picklers.
|
|
self._class_reducer_dict = {}
|
|
|
|
def _register_reducer(self, obj_class, reducer):
|
|
# For the same class, only register the reducer once.
|
|
if obj_class not in self._class_reducer_dict:
|
|
self._class_reducer_dict[obj_class] = reducer
|
|
|
|
@classmethod
|
|
def _tensor_receiver(cls, tensor_index):
|
|
global _thread_local_tensor_tables
|
|
return _thread_local_tensor_tables.recv_tables[tensor_index]
|
|
|
|
def _tensor_reducer(self, tensor):
|
|
global _thread_local_tensor_tables
|
|
_thread_local_tensor_tables.send_tables.append(tensor)
|
|
tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
|
|
return (_InternalRPCPickler._tensor_receiver, (tensor_index,))
|
|
|
|
@classmethod
|
|
def _py_rref_receiver(cls, rref_fork_data):
|
|
return dist.rpc.PyRRef._deserialize(rref_fork_data)
|
|
|
|
def _py_rref_reducer(self, py_rref):
|
|
rref_fork_data = py_rref._serialize()
|
|
return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))
|
|
|
|
def _rref_reducer(self, rref):
|
|
return self._py_rref_reducer(rref)
|
|
|
|
@classmethod
|
|
def _script_module_receiver(cls, script_module_serialized):
|
|
"""
|
|
Given a serialized representation of a ScriptModule created with torch.jit.save,
|
|
loads and returns the ScriptModule.
|
|
"""
|
|
f = io.BytesIO(script_module_serialized)
|
|
m = torch.jit.load(f)
|
|
return m
|
|
|
|
def _script_module_reducer(self, script_module):
|
|
"""
|
|
Serializes a ScriptModule.
|
|
"""
|
|
f = io.BytesIO()
|
|
torch.jit.save(script_module, f)
|
|
return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))
|
|
|
|
def serialize(self, obj):
|
|
r"""
|
|
Serialize non tensor data into binary string, tensor data into
|
|
tensor table
|
|
"""
|
|
f = io.BytesIO()
|
|
p = _pickler(f)
|
|
p.dispatch_table = self._dispatch_table
|
|
|
|
# rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref,
|
|
# user picklers could have different initialization function from _InternalRPCPickler,
|
|
# but all the user picklers should call serialize() and use _rref_reducer to pickle rref
|
|
# in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not
|
|
# compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor,
|
|
# so putting rref's dispatch table here
|
|
#
|
|
# The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
|
|
# The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
|
|
# Ignore type error because dispatch_table is defined in third-party package
|
|
p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index]
|
|
# An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
|
|
# Ignore type error because dispatch_table is defined in third-party package
|
|
p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index]
|
|
|
|
# Add dispatch pickling for ScriptModule or its subclass.
|
|
if isinstance(obj, torch.jit.ScriptModule):
|
|
# Ignore type error because dispatch_table is defined in third-party package
|
|
p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index]
|
|
|
|
# Install customized picklers.
|
|
for class_name in self._class_reducer_dict.keys():
|
|
p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index]
|
|
|
|
# save _thread_local_tensor_tables.send_tables if it is in nested call
|
|
global _thread_local_tensor_tables
|
|
if hasattr(_thread_local_tensor_tables, "send_tables"):
|
|
old_send_tables = _thread_local_tensor_tables.send_tables
|
|
else:
|
|
old_send_tables = None
|
|
_thread_local_tensor_tables.send_tables = []
|
|
|
|
p.dump(obj)
|
|
|
|
# restore _thread_local_tensor_tables.send_tables if return
|
|
# from nested call, otherwise clean up the table
|
|
tensors = _thread_local_tensor_tables.send_tables
|
|
if old_send_tables is not None:
|
|
_thread_local_tensor_tables.send_tables = old_send_tables
|
|
else:
|
|
del _thread_local_tensor_tables.send_tables
|
|
|
|
return (f.getvalue(), tensors)
|
|
|
|
def deserialize(self, binary_data, tensor_table):
|
|
r"""
|
|
Deserialize binary string + tensor table to original obj
|
|
"""
|
|
# save _thread_local_tensor_tables.recv_tables if it is in nested call
|
|
global _thread_local_tensor_tables
|
|
if hasattr(_thread_local_tensor_tables, "recv_tables"):
|
|
old_recv_tables = _thread_local_tensor_tables.recv_tables
|
|
else:
|
|
old_recv_tables = None
|
|
_thread_local_tensor_tables.recv_tables = tensor_table
|
|
|
|
try:
|
|
unpickler = _unpickler(io.BytesIO(binary_data))
|
|
ret = unpickler.load()
|
|
except AttributeError as e:
|
|
# Occurs when function is not found on module/class during
|
|
# unpickling.
|
|
except_str = (
|
|
str(e)
|
|
+ """ Default RPC pickler does not serialize
|
|
function code. Ensure that UDFs are defined on both caller and
|
|
callee modules."""
|
|
)
|
|
ret = AttributeError(except_str)
|
|
# Ensure the stack trace gets preserved
|
|
ret.__cause__ = e
|
|
|
|
# restore _thread_local_tensor_tables.recv_tables if return
|
|
# from nested call, otherwise clean up the table
|
|
if old_recv_tables is not None:
|
|
_thread_local_tensor_tables.recv_tables = old_recv_tables
|
|
else:
|
|
del _thread_local_tensor_tables.recv_tables
|
|
|
|
return ret
|
|
|
|
|
|
# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
|
|
_internal_rpc_pickler = _InternalRPCPickler()
|
|
|
|
|
|
def serialize(obj):
|
|
return _internal_rpc_pickler.serialize(obj)
|
|
|
|
|
|
def deserialize(binary_data, tensor_table):
|
|
return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
|
|
|
|
|
|
def _run_function(python_udf):
|
|
r"""
|
|
This function is exclusively called from C++.
|
|
See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
|
|
|
|
Runs a Python UDF and returns its return value.
|
|
Wraps any exception in ``RemoteException`` if the function raises.
|
|
"""
|
|
try:
|
|
if isinstance(python_udf, AttributeError):
|
|
raise python_udf
|
|
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
|
|
except Exception as e:
|
|
# except str = exception info + traceback string
|
|
except_str = (
|
|
f"On {_get_current_rpc_agent().get_worker_info()}:\n"
|
|
f"{repr(e)}\n{traceback.format_exc()}"
|
|
)
|
|
print(except_str, file=sys.stderr)
|
|
result = RemoteException(except_str, type(e))
|
|
return result
|
|
|
|
|
|
def _handle_exception(result):
|
|
if isinstance(result, RemoteException):
|
|
exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
|
|
# We wrap exception re-creation here in case some exception classes
|
|
# cannot be constructed directly from a string.
|
|
exc = None
|
|
try:
|
|
exc = result.exception_type(exception_msg)
|
|
except BaseException as e:
|
|
raise RuntimeError( # noqa: B904
|
|
f"Failed to create original exception type. Error msg was {str(e)}"
|
|
f" Original exception on remote side was {exception_msg}"
|
|
) from e
|
|
|
|
if exc is not None:
|
|
raise exc
|
|
|
|
|
|
def _build_rpc_profiling_key(
|
|
exec_type, func_name, current_worker_name, dst_worker_name
|
|
):
|
|
"""
|
|
Builds the key that RPC calls are profiled with using the autograd profiler.
|
|
This will be the name of the corresponding Event recorded in the profiler.
|
|
|
|
Args:
|
|
exec_type (RPCExecMode): Type of RPC/RRef call
|
|
func_name (str): Name of function being profiled.
|
|
current_worker_name (str): Name of current worker.
|
|
dst_worker_name (str): Name of the destination worker.
|
|
|
|
Returns:
|
|
String representing profiling key
|
|
"""
|
|
profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
|
|
return profile_key
|
|
|
|
|
|
def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
|
|
"""
|
|
This function should be called from RPC/RRef functions to create a
|
|
RecordFunction object for profiling. This function also runs the before
|
|
callbacks that start the profiling, though the user is responsible for
|
|
running the appropriate callbacks when the function to be profiled finishes.
|
|
|
|
Args:
|
|
exec_type (RPCExecMode): Type of RPC/RRef call
|
|
func_name (str): Name of function being profiled.
|
|
current_worker_name (str): Name of current worker.
|
|
dest_worker_name (str): Name of the destination worker.
|
|
|
|
Returns:
|
|
An instance of `torch.autograd._RecordFunction`.
|
|
"""
|
|
assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
|
|
profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})"
|
|
rf = torch.autograd._RecordFunction() # type: ignore[attr-defined]
|
|
torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined]
|
|
return rf
|
|
|
|
|
|
PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
|
|
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])
|