3RNN/Lib/site-packages/tensorflow/dtensor/python/config.py

220 lines
7.8 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor Configuration API."""
import os
from typing import List, Optional, Union
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.util.tf_export import tf_export
# Names of the environment variables consulted by the functions in this file.
# Read by `client_id()`; parsed as an integer, treated as "0" when unset.
_DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
# DTENSOR_NUM_CLIENTS is removed, but some DTensor users still use this symbol.
_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
# Read by `job_name()`; overrides the default job name when set.
_DT_JOB_NAME = "DTENSOR_JOB_NAME"
# Read by `jobs()`; a comma-separated list of job names.
_DT_JOBS = "DTENSOR_JOBS"
# Read by `heartbeat_enabled()`; "true" or "1" (case-insensitive) enables it.
_DT_HEARTBEAT_ENABLED = "DTENSOR_ENABLE_HEARTBEAT"
# All functions in this file can be used before calling
# `tf.experimental.dtensor.initialize_accelerator_system`.
# -----------------------------------------------------------------------------
# Distributed training-related methods.
#
# Most users should use DTensor utility methods to create a mesh. The methods
# here are only for advanced users who want to fully customize their meshes.
# Note that local_devices and num_local_devices return the actual number of
# locally attached devices. The others are set through environment variables.
@tf_export("experimental.dtensor.local_devices", v1=[])
def local_devices(
    device_type: str,
    for_client_id: Optional[int] = None) -> List[tf_device.DeviceSpec]:
  """Lists the fully qualified device specs of one type for a client.

  Args:
    device_type: One of "CPU", "GPU" or "TPU", matched case-insensitively.
    for_client_id: Task index to place in the returned specs. When None, the
      value returned by `client_id()` is used.

  Returns:
    One `DeviceSpec` per local device of `device_type`, ordered by increasing
    device index.

  Raises:
    ValueError: If `device_type` is not CPU, GPU, or TPU.
  """
  if device_type.upper() not in ["CPU", "GPU", "TPU"]:
    raise ValueError(f"Device type {device_type} is not CPU, GPU, or TPU.")

  task = client_id() if for_client_id is None else for_client_id

  specs = []
  for index in range(num_local_devices(device_type)):
    specs.append(
        tf_device.DeviceSpec(
            job=job_name(),
            replica=0,  # replica is deprecated and mostly hard-coded now.
            task=task,
            device_type=device_type,
            device_index=index))
  return specs
@tf_export("experimental.dtensor.num_local_devices", v1=[])
def num_local_devices(device_type: str) -> int:
  """Counts the devices of `device_type` configured on this client.

  Args:
    device_type: Device type name; "CPU" and "GPU" are matched
      case-insensitively.

  Returns:
    For CPU and GPU, the `device_count` entry of the context config; for any
    other type, the number of physical devices of that type.
  """
  normalized = device_type.upper()
  if normalized not in ["CPU", "GPU"]:
    return len(tf_config.list_physical_devices(device_type))
  # CPU and GPU can use logical devices, so the count comes from the config.
  return context.get_config().device_count[normalized]
@tf_export("experimental.dtensor.num_global_devices", v1=[])
def num_global_devices(device_type: str) -> int:
  """Counts the devices of `device_type` across the whole DTensor cluster.

  Args:
    device_type: Device type name, forwarded to `num_local_devices`.

  Returns:
    The local device count multiplied by the number of clients.
  """
  per_client = num_local_devices(device_type)
  return per_client * num_clients()
@tf_export("experimental.dtensor.client_id", v1=[])
def client_id() -> int:
  """Reads this client's ID from the environment.

  Returns:
    The integer value of the DTENSOR_CLIENT_ID environment variable, or 0
    when the variable is unset (a single client is assumed).

  Raises:
    ValueError: If the value is negative, is not smaller than
      `num_clients()`, or cannot be parsed as an integer.
  """
  # A missing variable means a single client whose ID is 0.
  current = int(os.environ.get(_DT_CLIENT_ID, "0"))

  if current < 0:
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be >= 0, got {current}. ")

  if current >= num_clients():
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be < {num_clients()}, got {current}")

  return current
@tf_export("experimental.dtensor.num_clients", v1=[])
def num_clients() -> int:
  """Counts the clients in this DTensor cluster.

  Returns:
    1 in local mode, otherwise the number of entries returned by `jobs()`.
  """
  return 1 if is_local_mode() else len(jobs())
@tf_export("experimental.dtensor.job_name", v1=[])
def job_name() -> str:
  """Gives the job name shared by all clients in this DTensor cluster.

  Returns:
    The value of the DTENSOR_JOB_NAME environment variable when set.
    Otherwise "localhost" when there is exactly one client, else "worker".
  """
  # With a single client the program is assumed to run locally, and
  # "localhost" is the job name per TensorFlow convention.
  fallback = "localhost" if num_clients() == 1 else "worker"
  return os.environ.get(_DT_JOB_NAME, fallback)
@tf_export("experimental.dtensor.full_job_name", v1=[])
def full_job_name(task_id: Optional[int] = None) -> str:
  """Builds the fully qualified TF job name for this or another task.

  Args:
    task_id: Task to name. When None, this client's ID is used, which is
      equal to its task ID.

  Returns:
    A string of the form "<job_name>/replica:0/task:<task_id>".

  Raises:
    ValueError: If there is exactly one client and the task ID is not 0.
  """
  resolved = client_id() if task_id is None else task_id

  # Local runs and unit tests have exactly one client on one TF task, so the
  # only valid task ID there is 0.
  if num_clients() == 1 and resolved != 0:
    raise ValueError(f"Unexpected task ID {resolved} in local runs")

  return "/".join([job_name(), "replica:0", f"task:{resolved}"])
def _bns_task_id(job: str) -> Union[int, str]:
"""Tries to extract an integer task ID from a job name.
For example, for `job` = '/.../tpu_worker/0:port_name', return 0.
Args:
job: A job name to extract task ID from.
Returns:
The task ID on success, or the original job name on failure.
"""
maybe_task_id = job.rsplit("/")[-1].rsplit(":")[0]
try:
return int(maybe_task_id)
except ValueError:
return job
@tf_export("experimental.dtensor.jobs", v1=[])
def jobs() -> List[str]:
  """Lists the job names of all clients in this DTensor cluster.

  Returns:
    The comma-separated entries of the DTENSOR_JOBS environment variable, or
    an empty list when the variable is unset.

  Raises:
    ValueError: If any entry starts with "/bns/" and the entries are not
      already in the order produced by sorting on `_bns_task_id`.
  """
  raw = os.environ.get(_DT_JOBS)
  if raw is None:
    return []

  names = raw.split(",")

  # Validate ordering for BNS style job names.
  # For definition of BNS, refer to https://research.google/pubs/pub43438/.
  has_bns_name = any(name.startswith("/bns/") for name in names)
  if has_bns_name and names != sorted(names, key=_bns_task_id):
    raise ValueError(
        f"Unexpected DTENSOR_JOBS content {raw}. Sort entries "
        "in DTENSOR_JOBS because cluster construction relies on "
        "the order.")

  return names
@tf_export("experimental.dtensor.heartbeat_enabled", v1=[])
def heartbeat_enabled() -> bool:
  """Tells whether the DTensor heartbeat service is enabled.

  Returns:
    True when the DTENSOR_ENABLE_HEARTBEAT environment variable is unset, or
    is "true" or "1" ignoring case; False for any other value.
  """
  setting = os.environ.get(_DT_HEARTBEAT_ENABLED, "true")
  return setting.lower() in ("true", "1")
def is_local_mode() -> bool:
  """Tells whether DTensor shall run in local mode.

  Returns:
    True when `jobs()` returns an empty list.
  """
  configured_jobs = jobs()
  return len(configured_jobs) == 0
def is_tpu_present() -> bool:
  """Tells whether TPU devices are present.

  Returns:
    True when at least one physical device of type "TPU_SYSTEM" is listed.
  """
  # TPU_SYSTEM is a device that indicates TPUs are present, so a non-empty
  # listing from the initialized context is sufficient.
  return bool(tf_config.list_physical_devices("TPU_SYSTEM"))
def is_gpu_present() -> bool:
  """Returns true if GPU devices are present.

  Returns:
    True when at least one physical device of type "GPU" is listed.
  """
  gpu_devices = tf_config.list_physical_devices("GPU")
  return bool(gpu_devices)
@tf_export("experimental.dtensor.preferred_device_type", v1=[])
def preferred_device_type() -> str:
  """Picks the preferred device type for the accelerators.

  The supported device types are checked in the order 'TPU', 'GPU', 'CPU',
  and the first one that is present wins.

  Returns:
    "TPU" if TPUs are present, else "GPU" if GPUs are present, else "CPU".
  """
  if is_tpu_present():
    return "TPU"
  return "GPU" if is_gpu_present() else "CPU"
def use_multi_device_mode() -> bool:
  """Tells whether the environment enables multi-device mode.

  Returns:
    True when DTENSOR_ENABLE_MULTI_DEVICE_EXPANSION is set to anything other
    than "0"; False when it is unset or "0".
  """
  flag = os.environ.get("DTENSOR_ENABLE_MULTI_DEVICE_EXPANSION", "0")
  return flag != "0"
def gpu_use_nccl_communication() -> bool:
  """Tells whether the environment requests NCCL for GPU communication.

  Returns:
    True when DTENSOR_GPU_USE_NCCL_COMMUNICATION is set to anything other
    than "0"; False when it is unset or "0".
  """
  flag = os.environ.get("DTENSOR_GPU_USE_NCCL_COMMUNICATION", "0")
  return flag != "0"
def backend_is_pw() -> bool:
  """Tells whether the environment selects the Pathways backend.

  Returns:
    True only when DTENSOR_USE_PARALLEL_EXECUTOR is exactly "pw".
  """
  executor = os.environ.get("DTENSOR_USE_PARALLEL_EXECUTOR")
  return executor == "pw"