# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to help with mesh creation."""

from typing import Dict, List, Optional, Tuple, Union

from absl import logging
import numpy as np

from tensorflow.dtensor.python import accelerator_util
from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import layout
from tensorflow.dtensor.python import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _print_context(num_global_devices: int, num_clients: int, client_id: int,
                   device_type: str, mesh: layout.Mesh) -> None:
  logging.info('This is client %d of %d clients', client_id, num_clients)
  logging.info('Number of global %s devices: %d', device_type.upper(),
               num_global_devices)
  # pylint: disable=protected-access
  logging.info('Global device IDs: %s', mesh.global_device_ids())
  logging.info('Local device IDs: %s', mesh.local_device_ids())
  logging.info('Local devices: %s', mesh.local_devices())
  # pylint: enable=protected-access


def _make_device_specs(
    devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
    device_type: Optional[str] = None
) -> Tuple[List[tf_device.DeviceSpec], str]:
  """Makes device specs for all local devices or from a provided list."""

  if devices is None:
    if device_type is None:
      device_type = 'CPU'
    devices = config.local_devices(device_type)
  else:
    if isinstance(devices[0], str):
      devices = [tf_device.DeviceSpec.from_string(d) for d in devices]
    if device_type is None:
      device_type = devices[0].device_type

    if device_type.upper() != devices[0].device_type.upper():
      raise ValueError(
          f'Conflicting devices {str(devices)} and device_type {device_type}'
      )

  return devices, device_type


@tf_export('experimental.dtensor.create_mesh', v1=[])
def create_mesh(
    mesh_dims: Optional[Union[List[Tuple[str, int]], Dict[str, int]]] = None,
    mesh_name: str = '',
    devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
    device_type: Optional[str] = None,
    use_xla_spmd: bool = layout.USE_XLA_SPMD,
) -> layout.Mesh:
  """Creates a single-client mesh.

  If both `mesh_dims` and `devices` are specified, they must match each other.
  As a special case, when all arguments are missing, this creates a 1D CPU mesh
  with an empty name, assigning all available devices to that dimension.
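
  Example:

  A small 2D CPU mesh can be created by listing dimension names and sizes. This
  is a minimal sketch; it assumes that eight logical CPU devices are available.

  ```python
  mesh = dtensor.create_mesh([('x', 4), ('y', 2)], device_type='CPU')

  # A single dimension of size -1 uses all available devices instead.
  mesh_1d = dtensor.create_mesh({'x': -1}, device_type='CPU')
  ```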

  Args:
    mesh_dims: A dict of dim_name: dim_size, or a list of (dim_name, dim_size)
      tuples. Defaults to a single batch-parallel dimension called 'x' using
      all devices. As a special case, a single-element mesh_dims whose dim_size
      is -1 also uses all devices. e.g. `{'x' : 4, 'y' : 1}` or `[('x', 4),
      ('y', 1)]`.
    mesh_name: Name of the created mesh. Defaults to ''.
    devices: String representations of devices to use. This is the device part
      of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available logical devices.
    device_type: If `devices` is missing, the type of devices to use. Defaults
      to 'CPU'.
    use_xla_spmd: Boolean. When True, use XLA SPMD instead of DTensor SPMD.

  Returns:
    A single-client mesh created from specified or default arguments.
  """
  device_specs, device_type = _make_device_specs(devices, device_type)

  local_spec = tf_device.DeviceSpec(job=config.job_name(), replica=0, task=0)
  device_specs = [local_spec.make_merged_spec(d) for d in device_specs]

  if isinstance(mesh_dims, dict):
    mesh_dims = list(mesh_dims.items())
  if mesh_dims is None:
    mesh_dims = [('x', len(device_specs))]
  elif len(mesh_dims) == 1 and mesh_dims[0][1] == -1:
    # Replace a -1 dim_size in a 1D mesh with the number of all devices.
    mesh_dims[0] = (mesh_dims[0][0], len(device_specs))

  dim_names = [d[0] for d in mesh_dims]
  shape = [d[1] for d in mesh_dims]

  if np.prod(shape) != len(device_specs):
    raise ValueError(f'length of devices ({len(device_specs)}) must be '
                     f'equal to total size of the mesh of shape {shape}')

  global_device_ids = np.arange(len(device_specs)).reshape(shape)
  local_device_ids = np.ravel(global_device_ids).tolist()
  mesh = layout.Mesh(
      dim_names=dim_names,
      global_device_ids=global_device_ids,
      local_device_ids=local_device_ids,
      local_devices=device_specs,
      mesh_name=mesh_name,
      use_xla_spmd=use_xla_spmd)
  _print_context(
      num_global_devices=len(device_specs),
      num_clients=1,
      client_id=0,
      device_type=device_type,
      mesh=mesh)
  return mesh


@tf_export('experimental.dtensor.create_distributed_mesh', v1=[])
def create_distributed_mesh(
    mesh_dims: Union[List[Tuple[str, int]], Dict[str, int]],
    mesh_name: str = '',
    local_devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
    device_type: Optional[str] = None,
    use_xla_spmd: bool = layout.USE_XLA_SPMD,
) -> layout.Mesh:
  """Creates a distributed mesh.

  This is similar to `create_mesh`, but with a different set of arguments to
  create a mesh that spans evenly across a multi-client DTensor cluster.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available, via `local_devices`.

  For TPU, only meshes that use all TPU cores are supported by the DTensor
  runtime.
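
  Example:

  A minimal sketch of a 2D mesh spanning a two-client CPU cluster. It assumes
  each client has four logical CPU devices and that every client has already
  called `dtensor.initialize_accelerator_system()`, so the mesh size
  (2 * 4 = 8) matches the global device count (2 clients * 4 devices).

  ```python
  mesh = dtensor.create_distributed_mesh(
      {'batch': 2, 'model': 4}, device_type='CPU')
  ```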

  Args:
    mesh_dims: A dict of dim_name: dim_size, or a list of (dim_name, dim_size)
      tuples. e.g. `{'x' : 4, 'y' : 1}` or `[('x', 4), ('y', 1)]`.
    mesh_name: Name of the created mesh. Defaults to ''.
    local_devices: String representations of devices to use. This is the device
      part of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available local
      logical devices.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.
      Supported values are 'CPU', 'GPU', 'TPU'.
    use_xla_spmd: Boolean. When True, use XLA SPMD instead of DTensor SPMD.

  Returns:
    A mesh that spans evenly across all DTensor clients in the cluster.
  """
  if isinstance(mesh_dims, dict):
    mesh_dims = list(mesh_dims.items())
  dim_names, shape = zip(*mesh_dims)

  if not accelerator_util.is_initialized():
    raise ValueError('Accelerators are uninitialized, please run '
                     'dtensor.initialize_accelerator_system() first.')

  if device_type and device_type.upper() == 'TPU':
    # TODO(b/185940495): Allow multi-mesh and partial on TPU.
    # TPU meshes can only be configured through environment variables that
    # reflect the actual TPU topology. Do not let users specify custom args.
    if local_devices is not None:
      raise ValueError(
          f'Do not specify devices for {device_type.upper()} meshes. '
          f'Using a partial list of devices for {device_type.upper()} '
          f'is not supported.')

  device_specs, device_type = _make_device_specs(local_devices, device_type)

  if device_type.upper() in ['CPU', 'GPU']:
    # For CPU and GPU meshes, user-specified args take precedence over env vars.
    # This is particularly useful on single clients when users want to create
    # meshes that use fewer logical devices than what's available.

    local_spec = tf_device.DeviceSpec(
        job=config.job_name(), replica=0, task=config.client_id())
    device_specs = [local_spec.make_merged_spec(d) for d in device_specs]

    # Assumes identical number of local devices per client.
    num_global_devices = len(device_specs) * config.num_clients()

    if np.prod(shape) != num_global_devices:
      raise ValueError(
          f'Global number of devices '
          f'({len(device_specs)} per client * {config.num_clients()} clients '
          f'= {num_global_devices}) must be '
          f'equal to total size of the mesh of shape {shape}')

    global_device_ids = np.arange(num_global_devices).reshape(shape)
    flattened = np.ravel(global_device_ids).tolist()
    start_idx = len(device_specs) * config.client_id()
    local_device_ids = flattened[start_idx:start_idx + len(device_specs)]

    mesh = layout.Mesh(
        dim_names=dim_names,
        global_device_ids=global_device_ids,
        local_device_ids=local_device_ids,
        local_devices=device_specs,
        mesh_name=mesh_name,
        use_xla_spmd=use_xla_spmd)
    _print_context(num_global_devices, config.num_clients(), config.client_id(),
                   device_type, mesh)
    return mesh

  if device_type.upper() == 'TPU':
    mesh = tpu_util.create_tpu_mesh(
        mesh_dim_names=dim_names,
        mesh_shape=shape,
        mesh_name=mesh_name,
        use_xla_spmd=use_xla_spmd)
    _print_context(
        config.num_global_devices(device_type), config.num_clients(),
        config.client_id(), device_type, mesh)
    return mesh

  raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')


_BARRIER_DICT = {}


@tf_export('experimental.dtensor.barrier', v1=[])
def barrier(mesh: layout.Mesh,
            barrier_name: Optional[str] = None,
            timeout_in_ms: Optional[int] = None):
"""Runs a barrier on the mesh.
|
|
|
|
Upon returning from the barrier, all operations run before the barrier
|
|
would have completed across all clients. Currently we allocate a fully
|
|
sharded tensor with mesh shape and run an all_reduce on it.
|
|
|
|
Example:
|
|
|
|
A barrier can be used before application exit to ensure completion of pending
|
|
ops.
|
|
|
|
```python
|
|
|
|
x = [1, 2, 3]
|
|
x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
|
|
dtensor.barrier(mesh)
|
|
|
|
# At this point all devices on all clients in the mesh have completed
|
|
# operations before the barrier. Therefore it is OK to tear down the clients.
|
|
sys.exit()
|
|
```
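
  A barrier name and a timeout can also be supplied; this is a minimal sketch
  and assumes a coordination service is configured, since the timeout is only
  enforced through it.

  ```python
  dtensor.barrier(mesh, barrier_name='before_exit', timeout_in_ms=60000)
  ```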

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
    timeout_in_ms: The timeout of the barrier in ms. If omitted, blocks
      indefinitely until the barrier is reached from all clients.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))

  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size : {0} while mesh has actual '
        'size : {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
  # from users. Consider dropping this if there is a `big` performance hit.
  context.async_wait()

  if context.context().coordination_service:
    if timeout_in_ms is None:
      timeout_in_ms = 24 * 60 * 60 * 1000  # 24 hours to stand in for infinite.

    num_calls = _BARRIER_DICT.setdefault(barrier_name, 0)
    _BARRIER_DICT[barrier_name] = num_calls + 1

    barrier_id = f'{barrier_name}:{num_calls}'
    context.context().wait_at_barrier(barrier_id, timeout_in_ms)

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)