Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/clusters/cluster.py

111 lines
4.7 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Type, Sequence, Tuple
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
logger = logging.getLogger(__name__)
class ClusterEnv:
  """Interface for defining a cluster environment.

  To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`),
  cluster environments need to derive from :class:`ClusterEnv` and implement
  :func:`is_env_present`, :func:`get_coordinator_address`,
  :func:`get_process_count`, and :func:`get_process_id`.

  :class:`ClusterEnv` subclasses are automatically detected when imported.
  """

  # Registry of every ClusterEnv subclass, in definition order. The list is
  # created once on the base class and shared by all subclasses; detection
  # picks the first entry whose ``is_env_present()`` returns True.
  _cluster_types: List[Type['ClusterEnv']] = []

  def __init_subclass__(cls, **kwargs):
    """Appends every newly defined subclass to ``_cluster_types``."""
    super().__init_subclass__(**kwargs)
    cls._cluster_types.append(cls)

  @classmethod
  # pytype: disable=bad-return-type
  def auto_detect_unset_distributed_params(
      cls,
      coordinator_address: Optional[str],
      num_processes: Optional[int],
      process_id: Optional[int],
      local_device_ids: Optional[Sequence[int]]
  ) -> Tuple[Optional[str], Optional[int], Optional[int],
             Optional[Sequence[int]]]:
    """Fills in distributed parameters that the caller left as ``None``.

    Values that are already set are never overwritten. If every argument is
    set, they are returned unchanged without probing any environment.
    Otherwise the first registered subclass whose :func:`is_env_present`
    returns True supplies the missing values. If no subclass matches, the
    arguments are returned as given (unset ones stay ``None``).

    Args:
      coordinator_address: "hostname:port" of the coordinator, or None.
      num_processes: number of processes, or None.
      process_id: id of this process, or None.
      local_device_ids: device ids visible to this process, or None.

    Returns:
      A ``(coordinator_address, num_processes, process_id, local_device_ids)``
      tuple.
    """
    if all(p is not None for p in (coordinator_address, num_processes,
                                   process_id, local_device_ids)):
      return (coordinator_address, num_processes, process_id,
              local_device_ids)

    env = next((env for env in cls._cluster_types if env.is_env_present()), None)
    if env:
      logger.debug('Initializing distributed JAX environment via %s', env.__name__)
      if coordinator_address is None:
        coordinator_address = env.get_coordinator_address()
      if num_processes is None:
        num_processes = env.get_process_count()
      if process_id is None:
        process_id = env.get_process_id()
      # Never automatically set local_device_ids on TPUs
      # Defaults to single process per device if local_process_id is available.
      # This only runs if we're in a managed distributed environment.
      # Otherwise local_device_ids will remain unset,
      # which will default to all devices being visible.
      # NOTE(review): running_in_cloud_tpu_vm is imported from
      # jax._src.cloud_tpu_init and used here as a plain truthy flag; its
      # definition is not visible in this file -- confirm it is not a callable.
      if local_device_ids is None and not running_in_cloud_tpu_vm:
        # Query the environment exactly once so the value that is checked is
        # the value that is stored (two calls could disagree and yield [None]).
        local_process_id = env.get_local_process_id()
        if local_process_id is not None:
          local_device_ids = [local_process_id]
    else:
      logger.debug('Could not find a known environment for initializing distributed JAX. '
                   'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
    return (coordinator_address, num_processes, process_id, local_device_ids)
  # pytype: enable=bad-return-type

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True if process is running in this cluster environment.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

  @classmethod
  def get_coordinator_address(cls) -> str:
    """Returns address and port used by JAX to bootstrap.

    Process id 0 will open a tcp socket at "hostname:port" where
    all the processes will connect to initialize the distributed JAX service.
    The selected port needs to be free.
    :func:`get_coordinator_address` needs to return the same hostname and port on all the processes.

    Returns:
      "hostname:port"
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the value used for ``num_processes`` when the caller left it unset.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")

  @classmethod
  def get_process_id(cls) -> int:
    """Returns the value used for ``process_id`` when the caller left it unset.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    """Get index of current process inside a host.

    The method is only useful to support single device per process.
    In that case, each process will see a local device whose ID is
    the same as its local process ID.
    If None, JAX will not restrict the visible devices.
    """
    return None