Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/distributed.py
2023-06-19 00:49:18 +02:00

177 lines
7.1 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import logging
import os
from typing import Any, Optional, Union, Sequence
from jax._src import clusters
from jax._src.config import config
from jax._src.lib import xla_extension
# Module-level logger; State.initialize uses it to report the visible devices
# and the coordinator address it starts/connects to.
logger = logging.getLogger(__name__)
class State:
  """Holds this process's handles to the JAX distributed runtime.

  The attributes below are class-level defaults; ``initialize`` overwrites
  them on the instance and ``shutdown`` resets the runtime handles to ``None``.
  """
  # ID of the current process; stays 0 until ``initialize`` stores the real one.
  process_id: int = 0
  # Total number of processes; stays 1 until ``initialize`` stores the real one.
  num_processes: int = 1
  # Distributed runtime service handle; only created when ``process_id == 0``.
  service: Optional[Any] = None
  # Distributed runtime client handle; created by every process.
  client: Optional[Any] = None
  # Only created when ``config.jax_coordination_service`` is truthy.
  preemption_sync_manager: Optional[Any] = None
  # Address the service/client were created with.
  coordinator_address: Optional[str] = None

  def initialize(self,
                 coordinator_address: Optional[str] = None,
                 num_processes: Optional[int] = None,
                 process_id: Optional[int] = None,
                 local_device_ids: Optional[Union[int, Sequence[int]]] = None):
    """Starts the service (process 0 only) and connects a client to it.

    Args:
      coordinator_address: address handed to the runtime service and client.
        When falsy, the ``JAX_COORDINATOR_ADDRESS`` environment variable is
        used instead.
      num_processes: number of processes taking part.
      process_id: ID of the current process; process 0 also starts the service.
      local_device_ids: a single int or a sequence of ints. When non-empty, the
        comma-joined IDs are written to the ``jax_cuda_visible_devices`` and
        ``jax_rocm_visible_devices`` config options.

    Any argument still unset is passed through
    ``clusters.ClusterEnv.auto_detect_unset_distributed_params`` before
    validation.

    Raises:
      RuntimeError: if this state already holds a client, or if it already
        holds a service and the resolved ``process_id`` is 0. In both cases the
        state and the config are left untouched.
      ValueError: if ``coordinator_address``, ``num_processes`` or
        ``process_id`` is still ``None`` after auto-detection.
    """
    # Reject a repeated call up front. Checking only after the assignments
    # below would let a rejected call overwrite the live state and config (and
    # even start a second service on process 0) before raising.
    if self.client is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    coordinator_address = (coordinator_address or
                           os.environ.get('JAX_COORDINATOR_ADDRESS', None))
    if isinstance(local_device_ids, int):
      local_device_ids = [local_device_ids]

    (coordinator_address, num_processes, process_id, local_device_ids) = (
        clusters.ClusterEnv.auto_detect_unset_distributed_params(
            coordinator_address, num_processes, process_id, local_device_ids
        )
    )

    if coordinator_address is None:
      raise ValueError('coordinator_address should be defined.')
    if num_processes is None:
      raise ValueError('Number of processes must be defined.')
    if process_id is None:
      raise ValueError('The process id of the current process must be defined.')

    # The service check needs the resolved process_id, so it runs here: after
    # validation but still before anything is mutated.
    if process_id == 0 and self.service is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    self.coordinator_address = coordinator_address

    if local_device_ids:
      visible_devices = ','.join(str(x) for x in local_device_ids)  # type: ignore[union-attr]
      logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
      config.update("jax_cuda_visible_devices", visible_devices)
      config.update("jax_rocm_visible_devices", visible_devices)

    self.process_id = process_id

    if process_id == 0:
      logger.info('Starting JAX distributed service on %s', coordinator_address)
      self.service = xla_extension.get_distributed_runtime_service(
          coordinator_address, num_processes, config.jax_coordination_service)

    self.num_processes = num_processes

    # Set init_timeout to 5 min to leave time for all the processes to connect
    self.client = xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id, config.jax_coordination_service,
        init_timeout=300)
    logger.info('Connecting to JAX distributed service on %s', coordinator_address)
    self.client.connect()

    if config.jax_coordination_service:
      self.initialize_preemption_sync_manager()

  def shutdown(self):
    """Shuts down the client, then the service, and clears all handles.

    Handles that are unset are skipped, so calling this on an uninitialized
    (or already shut down) state does nothing.
    """
    if self.client:
      self.client.shutdown()
      self.client = None
    if self.service:
      self.service.shutdown()
      self.service = None
    if self.preemption_sync_manager:
      # Only the reference is dropped; no shutdown method is called on it here.
      self.preemption_sync_manager = None

  def initialize_preemption_sync_manager(self):
    """Creates the preemption sync manager and initializes it with the client.

    Raises:
      RuntimeError: if a preemption sync manager has already been created.
    """
    if self.preemption_sync_manager is not None:
      raise RuntimeError(
          'Preemption sync manager should only be initialized once.')
    self.preemption_sync_manager = (
        xla_extension.create_preemption_sync_manager())
    self.preemption_sync_manager.initialize(self.client)
# Process-wide State instance; the module-level initialize() and shutdown()
# functions delegate to it.
global_state = State()
def initialize(coordinator_address: Optional[str] = None,
               num_processes: Optional[int] = None,
               process_id: Optional[int] = None,
               local_device_ids: Optional[Union[int, Sequence[int]]] = None):
  """Initializes the JAX distributed system.

  Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
  multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be
  called before performing any JAX computations.

  The JAX distributed system serves a number of roles:

    * it allows JAX processes to discover each other and share topology
      information,
    * it performs health checking, ensuring that all processes shut down if
      any process dies, and
    * it is used for distributed checkpointing.

  If you are using TPU, Slurm, or Open MPI, all arguments are optional: if
  omitted, they will be chosen automatically.

  Otherwise, you must provide the ``coordinator_address``, ``num_processes``,
  and ``process_id`` arguments to :func:`~jax.distributed.initialize`.

  Args:
    coordinator_address: the IP address of process `0` and a port on which
      that process should launch a coordinator service. The choice of port
      does not matter, so long as the port is available on the coordinator and
      all processes agree on the port.
      May be ``None`` only on supported environments, in which case it will be
      chosen automatically.
      Note that special addresses like ``localhost`` or ``127.0.0.1`` usually
      mean that the program will bind to a local interface and are not
      suitable when running in a multi-host environment.
    num_processes: Number of processes. May be ``None`` only on supported
      environments, in which case it will be chosen automatically.
    process_id: The ID number of the current process. The ``process_id``
      values across the cluster must be a dense range ``0``, ``1``, ...,
      ``num_processes - 1``.
      May be ``None`` only on supported environments; if ``None`` it will be
      chosen automatically.
    local_device_ids: Restricts the visible devices of the current process to
      ``local_device_ids``. If ``None``, defaults to all local devices being
      visible to the process except when processes are launched via Slurm and
      Open MPI on GPUs. In that case, it will default to a single device per
      process.

  Raises:
    RuntimeError: If :func:`~jax.distributed.initialize` is called more than
      once.

  Example:

  Suppose there are two GPU processes, and process 0 is the designated
  coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster,
  run the following commands before anything else.

  On process 0:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  # doctest: +SKIP

  On process 1:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)  # doctest: +SKIP
  """
  global_state.initialize(
      coordinator_address=coordinator_address,
      num_processes=num_processes,
      process_id=process_id,
      local_device_ids=local_device_ids,
  )
  # Registered only after a successful initialize, so the runtime handles are
  # released when the interpreter exits.
  atexit.register(shutdown)
def shutdown():
  """Shuts down the distributed system.

  Delegates to the module-level ``global_state``. Does nothing if the
  distributed system is not running.
  """
  global_state.shutdown()