284 lines
11 KiB
Python
284 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import sys
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
|
from torch.distributed.elastic import events, metrics
|
|
from torch.distributed.elastic.agent.server.api import WorkerSpec
|
|
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
|
|
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException
|
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
|
|
from torch.distributed.elastic.rendezvous import RendezvousParameters
|
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
|
|
from torch.distributed.elastic.utils.logging import get_logger
|
|
|
|
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class LaunchConfig:
    """
    Creates a rendezvous config.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
                   be launched on. Elastic agent ensures that the user
                   function start only when the min_nodes amount enters
                   the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
                   will be launched on.
        nproc_per_node: On each node the elastic agent will launch
                        this amount of workers that will execute user
                        defined function.
        logs_specs: Logging specification handed to the elastic agent. When
                    left as ``None`` a ``DefaultLogsSpecs()`` instance is
                    created in ``__post_init__``.
        run_id: The unique run id of the job. When left empty,
                ``launch_agent`` generates a random one.
        role: User defined role of the worker (defaults to "default_role").
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_backend: rdzv_backend to use in the rendezvous (defaults to "etcd").
        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
                      The mapping is shallow-copied, so the dict passed by the
                      caller is never modified.
        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
                      to be removed in future versions, see the note below. The default
                      timeout is 900 seconds.
        max_restarts: The maximum amount of restarts that elastic agent will conduct
                      on workers before failure.
        monitor_interval: The interval in seconds that is used by the elastic_agent
                          as a period of monitoring workers.
        start_method: The method is used by the elastic agent to start the
                      workers (spawn, fork, forkserver).
        log_line_prefix_template: Optional template that is forwarded unchanged
                                  to the elastic agent.
        metrics_cfg: configuration to initialize metrics.
        local_addr: address of the local node if any. If not set, a lookup on the local
                    machine's FQDN will be performed.

    .. note::
        `rdzv_timeout` is a legacy argument that will be removed in future.
        Set the timeout via `rdzv_configs['timeout']`

    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    logs_specs: Optional[LogsSpecs] = None
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_line_prefix_template: Optional[str] = None
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self):
        """Resolve the rendezvous timeout and fill in the default logs spec."""
        default_timeout = 900

        # Work on a private copy: the timeout is injected below and must not
        # leak into a dict the caller may be sharing between several configs.
        self.rdzv_configs = dict(self.rdzv_configs)

        # An explicit legacy ``rdzv_timeout`` wins over ``rdzv_configs``; the
        # default is only applied when neither of them provides a timeout.
        if self.rdzv_timeout != -1:
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout

        # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
        if self.logs_specs is None:
            self.logs_specs = DefaultLogsSpecs()
|
|
|
|
|
|
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    An instance binds a launch config to an entrypoint; calling the instance
    forwards both, together with the call's positional arguments, to
    ``launch_agent`` and returns whatever ``launch_agent`` returns.

    1. Pass the ``entrypoint`` arguments positionally (keyword arguments are
       not accepted). ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output mapped
       by their respective global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(config, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(config, "script.py")(args)
        outputs = elastic_launch(config, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Nothing is launched here; the pair is only remembered for __call__.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # launch_agent expects the entrypoint arguments as a list.
        entrypoint_args = list(args)
        return launch_agent(self._config, self._entrypoint, entrypoint_args)
|
|
|
|
|
|
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve entrypoint name with the rule:

    1. If entrypoint is a callable, use ``entrypoint.__name__``. Callables that
       have no ``__name__`` (for example ``functools.partial`` objects or
       callable instances) fall back to the ``__name__`` of their ``func``
       attribute when present, otherwise to the name of their type.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first
            non-empty element from ``args`` which does not start with a hyphen
            (for example, "-u" will be skipped).
        2.2 otherwise, use ``entrypoint`` value.
    3. Otherwise, return empty string.

    Args:
        entrypoint: Function, command string, or ``None``.
        args: Arguments that will be passed to the entrypoint.

    Returns:
        The name used to identify the entrypoint; "" when none can be derived.
    """
    if isinstance(entrypoint, str):
        if entrypoint != sys.executable:
            return entrypoint
        for arg in args:
            # str() keeps the declared return type for non-string arguments;
            # the emptiness check avoids indexing into "".
            text = str(arg)
            if text and not text.startswith("-"):
                return text
        return ""

    if callable(entrypoint):
        name = getattr(entrypoint, "__name__", None)
        if name is None:
            # functools.partial exposes the wrapped callable as ``func``.
            wrapped = getattr(entrypoint, "func", None)
            name = getattr(wrapped, "__name__", None)
        if name is None:
            name = type(entrypoint).__name__
        return name

    return ""
|
|
|
|
|
|
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from a static rendezvous endpoint.

    Args:
        rdzv_parameters: Rendezvous parameters whose ``backend`` and
            ``endpoint`` are inspected.

    Returns:
        ``(None, None)`` when the backend is not "static"; otherwise the
        ``(address, port)`` pair parsed from the whitespace-stripped endpoint.

    Raises:
        ValueError: If the backend is "static" and the endpoint is empty or
            carries no port.
    """
    is_static = rdzv_parameters.backend == "static"
    if not is_static:
        return (None, None)

    location = rdzv_parameters.endpoint.strip()
    if len(location) == 0:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
        )

    # -1 is the sentinel the parser hands back when no port is present.
    host, port = parse_rendezvous_endpoint(location, default_port=-1)
    if port != -1:
        return (host, port)

    raise ValueError(
        f"port is missing in endpoint: {location}. Try to specify --master-port"
    )
|
|
|
|
|
|
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run a ``LocalElasticAgent`` for ``entrypoint`` and return the workers' results.

    A random ``run_id`` is generated and stored on ``config`` when none is set.
    The launch settings are logged, and the rendezvous parameters, worker spec
    and agent are built from ``config`` before the agent is run.

    Args:
        config: Launch configuration; ``config.run_id`` is assigned in place
            when it is empty.
        entrypoint: Function or command that the workers execute.
        args: Arguments forwarded to the entrypoint.

    Returns:
        ``result.return_values`` of the agent run.

    Raises:
        ChildFailedError: If the agent run reports failure.
        SignalException: Re-raised as is; in this case the rendezvous handler
            is left open.
        Exception: Any other error is re-raised after a failure event has been
            recorded.
    """
    if not config.run_id:
        generated_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", generated_id)
        config.run_id = generated_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    # Values substituted into the %(name)s placeholders of the message below.
    launch_summary = {
        "entrypoint": entrypoint_name,
        "min_nodes": config.min_nodes,
        "max_nodes": config.max_nodes,
        "nproc_per_node": config.nproc_per_node,
        "run_id": config.run_id,
        "rdzv_backend": config.rdzv_backend,
        "rdzv_endpoint": config.rdzv_endpoint,
        "rdzv_configs": config.rdzv_configs,
        "max_restarts": config.max_restarts,
        "monitor_interval": config.monitor_interval,
        "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
        "metrics_cfg": config.metrics_cfg,
    }
    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n",
        launch_summary,
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    # Evaluated in the same order as the inline keyword arguments they replace.
    worker_args = tuple(args)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=worker_args,
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    shutdown_rdzv = True
    try:
        metrics_config = metrics.MetricsConfig(config.metrics_cfg)
        metrics.initialize_metrics(metrics_config)

        result = agent.run()
        # This event means agent.run() returned, NOT that the workers succeeded.
        events.record(agent.get_event_succeeded())

        if not result.is_failed():
            return result.return_values

        # ChildFailedError is treated specially by @record:
        # if the error files for the failed children exist,
        # @record will copy the first error (root cause)
        # to the error file of the launcher process.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=result.failures,
        )
    except ChildFailedError:
        # Propagate untouched; no failure event is recorded for this case.
        raise
    except SignalException:
        # The agent died with a signal: keep the rdzv_handler open, because
        # shutting it down closes the rendezvous on this rdzv_id permanently
        # and prevents any additional scaling events.
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
|