__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
           "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]

import collections
import contextlib
import functools
import inspect
import logging
import threading
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING

import torch
from torch.futures import Future

from torch._C._distributed_rpc import (
    PyRRef,
    RemoteProfilerManager,
    WorkerInfo,
    TensorPipeAgent,
    get_rpc_timeout,
    _cleanup_python_rpc_handler,
    _delete_all_user_and_unforked_owner_rrefs,
    _destroy_rref_context,
    _get_current_rpc_agent,
    _invoke_remote_builtin,
    _invoke_remote_python_udf,
    _invoke_remote_torchscript,
    _invoke_rpc_builtin,
    _invoke_rpc_python_udf,
    _invoke_rpc_torchscript,
    _is_current_rpc_agent_set,
    _reset_current_rpc_agent,
    _set_and_start_rpc_agent,
)

from .internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)

from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT

from ._utils import _group_membership_management, _update_group_membership

logger = logging.getLogger(__name__)

# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
# make sure there are no references to any RRef in the application code and
# Python GC has done its job to delete those RRefs. This could result in bad
# debugging experiences, especially for large applications. Therefore, by
# default, we are going to ignore RRef leaks during shutdown. This is usually
# fine as shutdown means applications have done training and no longer care
# about states.
#
# To enable RRef leak checking, set this _ignore_rref_leak to False
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler

@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
    """
    global _default_pickler
    _default_pickler = rpc_pickler
    try:
        yield
    finally:
        _default_pickler = _internal_rpc_pickler
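

# A minimal usage sketch for `_use_rpc_pickler` (hypothetical: assumes
# `my_pickler` implements the same serialize/deserialize interface as
# `_internal_rpc_pickler`):
#
#     with _use_rpc_pickler(my_pickler):
#         # RPCs in this block serialize Python UDFs with `my_pickler`;
#         # the default pickler is restored on exit.
#         rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))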


def _require_initialized(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _is_current_rpc_agent_set():
            raise RuntimeError(
                "RPC has not been initialized. Call "
                "torch.distributed.rpc.init_rpc first."
            )
        return func(*args, **kwargs)

    return wrapper
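

# Sketch of how `_require_initialized` is applied (hypothetical example;
# `my_rpc_helper` is not part of this module):
#
#     @_require_initialized
#     def my_rpc_helper(*args):
#         ...  # body may safely assume the RPC agent is set and started
#
# Calling `my_rpc_helper` before `torch.distributed.rpc.init_rpc` raises a
# RuntimeError instead of failing deeper inside the agent.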


class AllGatherStates:
    def __init__(self):
        # Each `gathered_objects` is an empty dict at the beginning.
        # The leader worker is elected as the first worker in a sorted worker
        # name list. Whenever there is a worker entering `_all_gather()`, it
        # runs `_gather_to_leader()` on the leader to add its own name and
        # data obj to this dict. The leader also adds its own name to the dict
        # on calling `_all_gather()`.
        # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
        # will broadcast the gathered dict to all follower workers and set their
        # `gathered_objects` field and the `proceed_signal` field.
        self.gathered_objects = {}
        # All workers wait on this signal until the gathered objects from all
        # workers have been received.
        self.proceed_signal = threading.Event()


# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initializing the RPC layer.
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)


def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: backend implementation might have already set the rpc_agent.
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)


def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    with _all_gather_dict_lock:
        if not worker_names:
            worker_names = _ALL_WORKER_NAMES
            assert (
                worker_name in worker_names
            ), f"{worker_name} is not expected by leader."
        states = _all_gather_sequence_id_to_states[sequence_id]
        assert (
            worker_name not in states.gathered_objects
        ), f"{worker_name} reported intent sequence id {sequence_id} twice."
        states.gathered_objects[worker_name] = obj
        if worker_names == set(states.gathered_objects.keys()):
            states.proceed_signal.set()


def _broadcast_to_followers(sequence_id, objects_map):
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

        assert (
            not states.proceed_signal.is_set()
        ), f"Termination signal sequence id {sequence_id} got set twice."
        states.gathered_objects = objects_map
        states.proceed_signal.set()
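
# Sketch of one gather round with workers {"w0", "w1"} ("w0" is the leader,
# as names are compared in sorted order):
#
#   1. "w0" runs `_gather_to_leader(seq_id, "w0", obj)` locally, while "w1"
#      invokes the same function on "w0" via `rpc_sync`.
#   2. Once both names appear in `states.gathered_objects`, the leader's
#      `proceed_signal` is set.
#   3. "w0" then calls `_broadcast_to_followers(seq_id, gathered)` on "w1"
#      via `rpc_async`, which installs the full dict and sets "w1"'s signal.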

_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits on them at the context manager's exit, relieving the user of the
    need to call ``wait`` explicitly.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>     fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>>     fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> # fut_1 and fut_2 are waited on
    """
    _thread_local_var.future_list = []
    try:
        yield
    finally:
        try:
            torch.futures.wait_all(_thread_local_var.future_list)
        finally:
            del _thread_local_var.future_list


@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but it uses RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all objects, it will broadcast the results back to all
    followers. This function blocks until all workers have received the
    gathered results.
    """
    if not worker_names:
        assert (
            _ALL_WORKER_NAMES is not None
        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        worker_names = _ALL_WORKER_NAMES
    leader_name = min(worker_names)

    self_name = _get_current_rpc_agent().get_worker_info().name

    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout is specified by agent for RPC calls
        rpc_timeout = get_rpc_timeout()
        # No timeout for signal
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC
        rpc_timeout = timeout
        # No timeout for signal
        signal_timeout = None
    else:
        # Signal and RPC timeout use the same timeout
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by function parameter or None (which is indefinite)
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcasts gathered results to all followers
    # Leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout
            )
            worker_name_to_response_future_dict[follower_name] = fut

        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up the states for this sequence_id
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects
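

# A minimal usage sketch for `_all_gather` (internal API; assumes RPC is
# initialized and every worker makes the matching call):
#
#     gathered = _all_gather(get_worker_info().name)
#     # On every worker, `gathered` maps each worker name to the object it
#     # contributed, e.g. {"worker0": "worker0", "worker1": "worker1"}.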


@_require_initialized
def _barrier(worker_names):
    r"""
    Synchronizes local and remote RPC processes.

    This will block until all local and remote RPC processes specified under
    ``worker_names`` reach this method to wait for all outstanding work to
    complete.

    Args:
        worker_names (List[str]): The set of workers to synchronize.

    """
    try:
        _all_gather(None, set(worker_names))
    except RuntimeError as ex:
        logger.error(
            "Failed to complete barrier, got error %s", ex
        )
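

# Hypothetical usage sketch for `_barrier` (internal API; every listed worker
# must make the matching call):
#
#     _barrier(["worker0", "worker1"])  # returns once both workers arrive
#
# Note that errors are only logged, not re-raised, so callers cannot rely on
# an exception to detect a failed barrier.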


@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Block until all local and remote RPC processes reach this method and wait
    for all outstanding work to complete. Every RPC process must call this
    method before exit to perform a graceful shutdown. This should be used to
    terminate the RPC framework, and there is no guarantee that the RPC
    framework will work after this method returns.
    """
    try:
        _all_gather(None, timeout=timeout)
    except RuntimeError as ex:
        logger.error(
            "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex
        )
        raise ex


@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there are no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group so we need to grab the token for the operation
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()


def _finalize_shutdown():
    try:
        # This raises a `TORCH_CHECK()` exception on RRef leak detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # clean up python rpc handler in shutdown(), see comments in
        # PythonRpcHandler::cleanup(), call it in python API because the
        # cleanup() function has python dependency, it assumes python
        # interpreter exists.
        # No matter if RRef leak exception is raised, this clean-up code
        # must run to avoid destruction segfault in Python 3.5.
        #
        # future.wait() should not be called after shutdown().
        # pythonRpcHandler is cleaned up in shutdown(), after
        # shutdown(), python objects returned from rpc python call can not be
        # resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()


@_require_initialized
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return the
                           id of the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    if worker_name is not None:
        return _get_current_rpc_agent().get_worker_info(worker_name)
    else:
        return _get_current_rpc_agent().get_worker_info()


def _to_worker_info(to):
    if isinstance(to, WorkerInfo):
        return to
    elif isinstance(to, (str, int)):
        return get_worker_info(to)
    else:
        raise ValueError(f"Cannot get WorkerInfo from name {to}")
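

# `_to_worker_info` normalizes the three accepted destination forms, e.g.
# (sketch, assuming an initialized two-worker group):
#
#     _to_worker_info("worker1")                   # by name
#     _to_worker_info(1)                           # by rank
#     _to_worker_info(get_worker_info("worker1"))  # already a WorkerInfo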


def _rref_typeof_on_owner(rref, blocking: bool = True):
    rref_type = type(rref.local_value())
    if blocking:
        return rref_type
    else:
        # Wrap result into a completed Future. This is so that if
        # blocking=False is specified, we return a future regardless of
        # whether this call is on user or owner.
        future = Future[type]()
        future.set_result(rref_type)
        return future


def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
    fut = rpc_async(
        rref.owner(),
        _rref_typeof_on_owner,
        args=(rref,),
        timeout=timeout
    )
    if blocking:
        return fut.wait()
    else:
        return fut
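

# Sketch of the two calling modes (hypothetical `rref` created via `remote`):
#
#     t = _rref_typeof_on_user(rref)                    # blocks, returns a type
#     fut = _rref_typeof_on_user(rref, blocking=False)  # returns a Future
#     t = fut.wait()                                    # resolve it later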


T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]


if TYPE_CHECKING:
    class RRef(PyRRef[T], Generic[T]):
        pass
else:
    try:
        # Combine the implementation class and the type class.
        class RRef(PyRRef, Generic[T]):
            pass
    except TypeError:
        # TypeError: metaclass conflict: the metaclass of a derived class
        # must be a (non-strict) subclass of the metaclasses of all its bases
        # Mypy doesn't understand __class__ (mypy bug #4177)
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
            pass

        # Combine the implementation class and the type class.
        # Types for classes expecting a certain generic parameter (mypy bug #7791)
        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
            pass
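

# The generic parameter is only used for static typing; a hypothetical
# annotation sketch (assumes an initialized RPC group):
#
#     rref: RRef[torch.Tensor] = remote(
#         "worker1", torch.add, args=(torch.ones(2), 1)
#     )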


# Install docstrings from `PyRRef` to `RRef`.
#
# This is for the fact that pybind11 generates the parameter
# `self` as type `rpc.PyRRef`, so a `:inherited-members:`
# under `.. autoclass:: RRef` does not work.
# We have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`.
#
def method_factory(method_name, docstring):
    def method(self, *args, **kwargs):
        return getattr(super(RRef, self), method_name)(*args, **kwargs)

    if method.__doc__:
        method.__doc__ = docstring
    return method


for method_name, method in inspect.getmembers(PyRRef):
    # Ignore magic methods, except "__str__".
    if method_name.startswith("_") and method_name != "__str__":
        continue

    # Get pybind11 generated docstring.
    # It's like,
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object

        Blocking call that copies the value of the RRef from the owner
        to the local node and returns it. If the current node is the
        owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."

    # Do surgery on pybind11 generated docstrings.
    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")

    # Attach user-facing RRef method with modified docstring.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)
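

# For example, after the surgery above (sketch of the resulting docstring):
#
#     print(RRef.to_here.__doc__)
#     # to_here(self: torch.distributed.rpc.RRef, timeout: float=-1.0) -> object
#     # Blocking call that copies the value of the RRef from the owner
#     # to the local node and returns it. ...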


@_require_initialized
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.

        timeout (float, optional): timeout in seconds for this remote call. If
                                   the creation of this
                                   :class:`~torch.distributed.rpc.RRef` on worker
                                   ``to`` is not successfully processed on this
                                   worker within this timeout, then the next time
                                   there is an attempt to use the RRef (such as
                                   ``to_here()``), a timeout will be raised
                                   indicating this failure. A value of 0 indicates
                                   an infinite timeout, i.e. a timeout error will
                                   never be raised. If not provided, the default
                                   value set during initialization or with
                                   ``_set_rpc_timeout`` is used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis. This means that when remote calls initiated by
        ``remote`` fail, such as with a timeout error, the errors are handled
        and set on the resulting RRef asynchronously. If the RRef has not been
        used by the application before this handling (such as ``to_here`` or
        fork call), then future uses of the ``RRef`` will appropriately raise
        errors. However, it is possible that the user application will use the
        ``RRef`` before the errors are handled. In this case, errors may not be
        raised as they have not yet been handled.

    Example::

        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = _get_should_profile()

    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *args,
                **kwargs,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            rref = _invoke_remote_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                timeout,
                is_async_exec
            )
        # attach profiling information
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            fut = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(fut)

    return rref


def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    should_profile = _get_should_profile()

    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *args,
                **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that is completed when the original future
            # completes and the profiling callbacks have been completed as well,
            # to guarantee that fut.wait() completes the profiling. This new
            # future will contain the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut
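

# `_invoke_rpc` picks one of three serialization paths based on the callable.
# A sketch of which branch each kind of callable takes (hypothetical
# `my_script_fn`/`my_python_fn`, assuming an initialized group):
#
#     _invoke_rpc("worker1", torch.add, RPCExecMode.ASYNC, ...)     # builtin op
#     _invoke_rpc("worker1", my_script_fn, RPCExecMode.ASYNC, ...)  # TorchScript
#     _invoke_rpc("worker1", my_python_fn, RPCExecMode.ASYNC, ...)  # Python UDF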


@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns the result of running ``func`` with ``args`` and ``kwargs``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    """
    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return fut.wait()


@_require_initialized
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe. This method will immediately return a
    :class:`~torch.futures.Future` that can be awaited on.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns a :class:`~torch.futures.Future` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
        object.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire. You
        need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``rpc_async`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned
        :class:`~torch.futures.Future` completes.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_async")
    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
    if hasattr(_thread_local_var, "future_list"):
        _thread_local_var.future_list.append(fut)
    return fut


def _get_should_profile():
    # Legacy profiler should be enabled. RPC profiling is not supported with
    # Kineto profiler.
    ActiveProfilerType = torch._C._profiler.ActiveProfilerType
    return (
        torch.autograd._profiler_enabled() and
        torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
    )
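

# Since RPC profiling requires the legacy profiler, a sketch of enabling it
# around an RPC (assumes `use_kineto=False` selects the legacy backend, which
# is the default for `torch.autograd.profiler.profile`):
#
#     with torch.autograd.profiler.profile(use_kineto=False):
#         rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))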


def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
    ctx_manager = contextlib.nullcontext()

    if should_profile:
        # Create appropriate string representation based on type of func
        # (builtin, script, python)
        if qualified_name is None:
            func_name = (
                torch._jit_internal._qualified_name(func)
                if isinstance(func, torch.jit.ScriptFunction)
                else func.__qualname__
            )
        else:
            func_name = qualified_name
        # Build RPC profiling key.
        rpc_profiling_key = _build_rpc_profiling_key(
            rpc_type,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
        # Mypy doesn't support re-def of a variable not in the same block (#1174)
        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]

    return ctx_manager