Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/array_serialization/serialization.py
2023-06-19 00:49:18 +02:00

529 lines
20 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GlobalDeviceArray serialization and deserialization."""
import abc
import asyncio
import itertools
import logging
from functools import partial
import os
import re
import time
import threading
from typing import Awaitable, Any, Callable, Dict, Optional, Sequence, Union
import jax
from jax._src import distributed
from jax._src import array
from jax._src import sharding
from jax._src import sharding_impls
from jax._src import typing
import jax.numpy as jnp
import numpy as np
import tensorstore as ts
# Default TensorStore context shared by the (de)serialization helpers below;
# configured with a `file_io_concurrency` limit of 128.
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
# NOTE(review): not referenced in the visible part of this file.
_REMOVED_VALUE = 'Value removed'
# Value process 0 stores in the distributed key-value store after the
# on-commit callback has run.
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
# Counter from which each asynchronous commit draws the id used to build its
# barrier / key-value-store key.
_module_unique_count = itertools.count()
logger = logging.getLogger(__name__)
async def create_async_array_from_callback(
    global_shape: array.Shape,
    inp_sharding: sharding_impls.XLACompatibleSharding,
    data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
):
  """Builds a global array from shards fetched concurrently per device.

  Args:
    global_shape: Shape of the full array being assembled.
    inp_sharding: Sharding that supplies the device -> index mapping and the
      list of addressable devices.
    data_callback: Called once per addressable device with that device's
      index and the device itself; the awaitable it returns yields the
      single-device array for that shard.

  Returns:
    The result of `array.make_array_from_single_device_arrays` applied to the
    gathered per-device arrays.
  """
  index_by_device = inp_sharding.devices_indices_map(global_shape)
  pending_shards = []
  for device in inp_sharding._addressable_device_assignment:
    shard_future = data_callback(index_by_device[device], device)  # type: ignore
    pending_shards.append(shard_future)
  # All callbacks are awaited together so the shard loads can overlap.
  device_buffers = await asyncio.gather(*pending_shards)
  return array.make_array_from_single_device_arrays(
      global_shape, inp_sharding, device_buffers)
def _get_metadata(arr):
if arr.dtype == jnp.bfloat16:
# Tensorstore uses 'bfloat16', not '<V2'.
dtype = 'bfloat16'
else:
dtype = np.dtype(arr.dtype).str
local_shape = arr.addressable_data(0).shape
return {
'compressor': {
'id': 'zstd'
},
'shape': arr.shape,
'chunks': np.array(np.maximum(1, local_shape)),
'dtype': dtype,
}
def _spec_has_metadata(tree):
if not isinstance(tree, dict):
return False
return 'metadata' in tree or any(
_spec_has_metadata(subtree) for _, subtree in tree.items())
def _get_kvstore_for_gcs(ckpt_path: str):
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
if m is None:
raise ValueError('The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}')
gcs_bucket = m.group(1)
path_without_bucket = m.group(2)
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
  """Builds a zarr TensorStore spec for a checkpoint path.

  Args:
    ckpt_path: Local filesystem path or `gs://bucket/path` URL. Trailing and
      duplicate separators are normalized away.
    ocdbt: If True, use an 'ocdbt' kvstore whose base is the parent directory
      of `ckpt_path` and whose path is its last component. Otherwise use a
      'gcs' or 'file' kvstore pointing at `ckpt_path` itself.

  Returns:
    A dict with 'driver': 'zarr' and the selected 'kvstore'.

  Raises:
    ValueError: if `ocdbt` is True and a non-GCS `ckpt_path` is not absolute.
  """
  # Local import: `posixpath` is not among this module's top-level imports.
  import posixpath

  # GCS object names always use '/', so they must be normalized with POSIX
  # rules on every platform. `os.path.normpath` on Windows rewrites '/' to
  # '\\', after which the 'gs://' prefix can no longer be recognized.
  path_mod = posixpath if ckpt_path.startswith('gs:/') else os.path
  # Normalize path to exclude trailing '/'. In GCS path case, we will need to
  # fix the path prefix to add back the stripped '/'.
  ckpt_path = path_mod.normpath(ckpt_path).replace('gs:/', 'gs://')
  is_gcs_path = ckpt_path.startswith('gs://')
  spec = {'driver': 'zarr', 'kvstore': {}}
  if ocdbt:
    if not is_gcs_path and not os.path.isabs(ckpt_path):
      raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
    base_path = path_mod.dirname(ckpt_path)
    spec['kvstore'] = {
        'driver': 'ocdbt',
        'base': base_path if is_gcs_path else f'file://{base_path}',
        'path': path_mod.basename(ckpt_path),
    }
  elif is_gcs_path:
    spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
  else:
    spec['kvstore'] = {'driver': 'file', 'path': ckpt_path}
  return spec
# Lifted from T5X.
class _LimitInFlightBytes:
"""Limits in-flight bytes when reading/writing checkpoints per process."""
def __init__(self, num_bytes):
self._max_bytes = num_bytes
self._available_bytes = num_bytes
self._cv = asyncio.Condition(lock=asyncio.Lock())
async def wait_for_bytes(self, requested_bytes):
if requested_bytes >= self._max_bytes:
raise ValueError('Requested more bytes than we reserved space for: '
f'{requested_bytes} > {self._max_bytes}')
async with self._cv:
await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
self._available_bytes -= requested_bytes
assert self._available_bytes >= 0
async def release_bytes(self, requested_bytes):
async with self._cv:
self._available_bytes += requested_bytes
assert self._available_bytes <= self._max_bytes
self._cv.notify_all()
async def async_serialize(
    arr_inp, tensorstore_spec, commit_future=None, context=TS_CONTEXT
):
  """Writes the addressable shards of `arr_inp` to TensorStore.

  Args:
    arr_inp: Array to serialize; its `addressable_shards` with
      `replica_id == 0` are written.
    tensorstore_spec: TensorStore spec dict. It is modified in place: if no
      'metadata' key is found anywhere in it, one derived from `arr_inp` is
      added at the top level.
    commit_future: Optional list. When given, the open future created by
      process 0 and every shard's commit future are appended to it, and only
      the copy step of each write is awaited here. When None, both the open
      and the commits are awaited before returning.
    context: TensorStore context used for every `ts.open` call.

  Returns:
    The list produced by gathering the per-shard write coroutines.

  Raises:
    ValueError: if a fully addressable `ArrayImpl` is passed while more than
      one process is running.
  """
  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
      arr_inp.is_fully_addressable):
    raise ValueError('Passing fully addressable Arrays to a multiprocess '
                     'serialization is not allowed.')
  # 'metadata' may not be present at the top level (for example, if we are using
  # a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_metadata(arr_inp)

  if jax.process_index() == 0:
    open_future = ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=context,
    )
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process 0 because for the first time, we just get
  # the future to be awaited upon in the background thread. The second one runs
  # with `assume_metadata=True` which does no I/O operation and returns the
  # tensorstore object.
  # For every process other than `0`, we open with `assume_metadata=True`.
  t = await ts.open(
      ts.Spec(tensorstore_spec),
      open=True,
      assume_metadata=True,
      context=context,
  )

  async def _write_array(shard):
    # Only replica 0 of each shard is written; other replicas are skipped.
    if shard.replica_id == 0:
      write_future = t[shard.index].write(shard.data)
      if commit_future is not None:
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  # The original branched on `isinstance(arr_inp, array.ArrayImpl)` here, but
  # both branches read the same attribute, so the conditional was dead code.
  local_shards = arr_inp.addressable_shards
  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
  return await asyncio.gather(*future_write_state)
def run_serialization(arrays, tensorstore_specs):
  """Serializes `arrays` to `tensorstore_specs`, blocking until all writes finish.

  `async_serialize` is mapped over the two trees and the resulting coroutines
  are gathered in a fresh event loop. Nothing is returned.
  """

  async def _serialize_all():
    pending_writes = jax.tree_util.tree_map(
        async_serialize, arrays, tensorstore_specs)
    return await asyncio.gather(*pending_writes)

  asyncio.run(_serialize_all())
def estimate_read_memory_footprint(t: ts.TensorStore,
                                   domain: ts.IndexDomain) -> int:
  """Estimates the bytes needed to read `domain` from `t`.

  Args:
    t: TensorStore to be read.
    domain: Region to read; when None the whole domain of `t` is used.

  Returns:
    Item size times the number of elements covered once `domain` is expanded
    outwards to whole read chunks in every dimension, or item size times
    `domain.size` when any chunk dimension is not finite.
  """
  ndim = t.rank
  itemsize = t.dtype.numpy_dtype.itemsize
  chunk_template = t.chunk_layout.read_chunk_template
  if domain is None:
    domain = t.domain

  origin = domain.origin
  shape = domain.shape
  chunk_origin = chunk_template.origin
  chunk_shape = chunk_template.shape

  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
  # For those, instead of returning a near-infinite memory footprint, estimate
  # the footprint as the entire shape.
  if any(not chunk_template[axis].finite for axis in range(ndim)):
    return domain.size * itemsize

  # Otherwise, if we have a chunked driver, estimate based on chunk size.
  footprint = itemsize
  for axis in range(ndim):
    chunk_len = chunk_shape[axis]
    # Express the requested interval relative to the chunk grid origin.
    start = origin[axis] - chunk_origin[axis]
    stop = origin[axis] + shape[axis] - chunk_origin[axis]
    aligned_start = (start // chunk_len) * chunk_len  # round down
    aligned_stop = -(-stop // chunk_len) * chunk_len  # round up
    footprint *= (aligned_stop - aligned_start)
  return footprint
async def async_deserialize(
    in_sharding: sharding_impls.XLACompatibleSharding,
    tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
    global_shape: Optional[Sequence[int]] = None,
    dtype=None,
    byte_limiter: Optional[_LimitInFlightBytes] = None,
    context=TS_CONTEXT,
    assume_metadata: bool = False,
):
  """Reads an array from TensorStore and lays it out according to `in_sharding`.

  Args:
    in_sharding: Sharding of the array to produce.
    tensorstore_spec: Spec of the TensorStore to read from.
    global_shape: Shape of the array to produce; defaults to the stored shape.
      Regions that lie outside the stored domain are left as zeros.
    dtype: If given, each shard is cast to this dtype on the host before being
      placed on its device.
    byte_limiter: If given, each shard read first reserves its estimated
      footprint from this limiter and releases it after `jax.device_put`.
    context: TensorStore context passed to `ts.open`.
    assume_metadata: Forwarded to `ts.open`.

  Returns:
    The array assembled by `create_async_array_from_callback`.
  """
  store = await ts.open(
      tensorstore_spec,
      open=True,
      assume_metadata=assume_metadata,
      context=context,
  )
  shape = global_shape if global_shape is not None else store.shape
  shard_shape = in_sharding.shard_shape(tuple(shape))

  async def cb(index: array.Index, device: jax.Device):
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    restricted_domain = store.domain.intersect(requested_domain)
    requested_bytes = estimate_read_memory_footprint(store, restricted_domain)

    # Limit the bytes read for every shard.
    if byte_limiter is not None:
      await byte_limiter.wait_for_bytes(requested_bytes)

    # The array may have been saved with a smaller shape than the one it is
    # being reloaded into, so start from zeros and let the read overwrite only
    # the part that exists in storage.
    host_buffer = np.zeros(shard_shape, dtype=store.dtype.numpy_dtype)
    shifted_view = ts.array(host_buffer)[
        ts.d[:].translate_to[requested_domain.origin]]
    destination = shifted_view[restricted_domain]
    await destination.write(store[restricted_domain])

    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      host_buffer = host_buffer.astype(dtype)

    result = jax.device_put(host_buffer, device)

    if byte_limiter is not None:
      # NB: the host buffer might not be ready for garbage collection by the
      # time release_bytes is called, so peak memory usage can still exceed
      # what the byte_limiter limit suggests. The simplest option would be to
      # call `result.block_until_ready()` here, but that comes with a ~15-20%
      # perf penalty as we would be waiting for the CPU->GPU transfer instead
      # of loading data. In the future, if memory pressure becomes a problem,
      # the limiter could instead track all in-flight tensors and only
      # block_until_ready when it hits the limit, reducing memory usage
      # without losing performance in common use cases.
      await byte_limiter.release_bytes(requested_bytes)

    return result

  return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
def run_deserialization(shardings: Sequence[sharding.Sharding],
                        tensorstore_specs: Sequence[Dict[str, Any]],
                        global_shapes: Optional[Sequence[array.Shape]] = None,
                        dtypes: Optional[Sequence[typing.DTypeLike]] = None,
                        concurrent_gb: int = 32):
  """Deserializes one array per spec, blocking until all reads finish.

  Args:
    shardings: Target sharding for each array.
    tensorstore_specs: TensorStore spec for each array.
    global_shapes: Optional target shape per array; None means no override.
    dtypes: Optional target dtype per array; None means no cast.
    concurrent_gb: Budget, in units of 10**9 bytes, shared by all reads
      started by this call.

  Returns:
    The list of deserialized arrays.
  """
  concurrent_bytes = concurrent_gb * 10**9

  async def _deserialize_all():
    # Object should be created once per process.
    limiter = _LimitInFlightBytes(concurrent_bytes)
    deserialize_one = partial(async_deserialize, byte_limiter=limiter)

    if global_shapes is None:
      shapes = [None] * len(tensorstore_specs)
    else:
      shapes = global_shapes
    if dtypes is None:
      target_dtypes = [None] * len(tensorstore_specs)
    else:
      target_dtypes = dtypes

    pending_arrays = jax.tree_util.tree_map(
        deserialize_one, shardings, tensorstore_specs, shapes, target_dtypes)
    return await asyncio.gather(*pending_arrays)

  return asyncio.run(_deserialize_all())
def _get_key(key: str):
return f'tensorstore_checkpoint_{key}'
class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
  """Interface for checkpointing GDAs asynchronously.

  This class manages the state of an ongoing asynchronous checkpoint.

  For example, say a checkpoint happens on every step. You checkpoint on
  step 1 and, after some computation, the model reaches step 2, but step 1's
  checkpoint hasn't finished committing to the storage layer yet. Until it has
  finished, the checkpoint for step 2 needs to be blocked. Keeping this in a
  class allows that state to be tracked.

  Example:

  Below is a simplified training loop:

  ```
  # Call this at the start of your program.
  jax.distributed.initialize()

  manager = GlobalAsyncCheckpointManager()

  # Restore checkpoint if available or initialize the train_state from
  # init_fn().
  train_state = manager.deserialize(...)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, temp_checkpoint_dir=...,
                        final_checkpoint_dir=...)
      train_state = train_step(train_state, input)
      # This is a non-blocking call.
      manager.check_for_errors()

  manager.serialize(train_state, temp_checkpoint_dir=...,
                    final_checkpoint_dir=...)
  # Wait before the end of the program for the checkpoint to finish. This is a
  # blocking call.
  manager.wait_until_finished()
  ```
  """

  @abc.abstractmethod
  def check_for_errors(self):
    """Checks if any errors have been raised in the child thread.

    This is a non-blocking call that can be called in the main thread.
    """

  @abc.abstractmethod
  def wait_until_finished(self):
    """Blocks until serialization has finished."""

  @abc.abstractmethod
  def serialize(self, arrays, tensorstore_specs, *,
                on_commit_callback: Callable[[], None]):
    """Serializes GDAs to TensorStore.

    Args:
      arrays: Arrays to serialize.
      tensorstore_specs: TensorStore specs, one per array.
      on_commit_callback: Zero-argument callable; keyword-only.
    """

  @abc.abstractmethod
  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[Dict[str, Any]],
                  global_shapes: Optional[Sequence[array.Shape]] = None,
                  dtypes: Optional[Sequence[typing.DTypeLike]] = None):
    """Deserializes GDAs from TensorStore.

    Args:
      shardings: Target sharding for each array.
      tensorstore_specs: TensorStore spec for each array.
      global_shapes: Optional target shape for each array.
      dtypes: Optional target dtype for each array.
    """
class AsyncManager:
  """Waits for commit futures on a background thread and reports the outcome.

  With more than one process, all processes meet at a barrier after their
  futures resolve; process 0 then runs the commit callback and publishes a
  success value in the distributed key-value store, which
  `wait_until_finished` blocks on.
  """

  def __init__(self, timeout_secs=300):
    """Sets up manager state.

    Args:
      timeout_secs: Timeout, in seconds, for the barrier and for the blocking
        key-value read.

    Raises:
      ValueError: if several processes are running but the distributed client
        has not been initialized.
    """
    self._timeout_secs = timeout_secs
    self._timeout_in_ms = self._timeout_secs * 1000

    self._commit_futures = None
    self._thread = None
    self._exception = None

    if jax.process_count() > 1:
      client = distributed.global_state.client
      if client is None:
        raise ValueError('Please initialize the distributed system via '
                         '`jax.distributed.initialize()` at the start of your '
                         'program.')
      self._client = client
    self._count = None

  def __del__(self):
    worker = self._thread
    if worker is None or not worker.is_alive():
      return
    logger.warning('Please add `.wait_until_finished()` in the main thread '
                   'before your program finishes because there is a '
                   'possibility of losing errors raised if the '
                   'this class is deleted before writing is completed.')

  def _thread_func(self):
    """Background thread body; any exception is stored for later re-raise."""
    try:
      current_process = jax.process_index()
      process_count = jax.process_count()
      multiprocess = process_count > 1
      logger.info('Starting commit to storage layer by process: %s',
                  current_process)
      thread_start_time = time.time()

      # Block until every pending commit has resolved (or raised).
      for pending in self._commit_futures:
        pending.result()
      logger.info('Finished committing to storage layer by process: %s',
                  current_process)

      if multiprocess:
        # All processes will wait at the barrier. When all processes are at the
        # barrier, the barrier will be satisfied. If not, then it will timeout.
        key_for_barrier = _get_key(self._count)
        logger.info('Key used for barrier is %s for process %s',
                    key_for_barrier, current_process)
        self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
        logger.info('Finished waiting at barrier for process %s',
                    current_process)

      if current_process == 0:
        self._on_commit_callback()
        logger.info('on_commit_callback successfully ran!')
        if multiprocess:
          self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
          logger.info('Process 0 successfully set key %s in the kv store',
                      key_for_barrier)

      elapsed_secs = time.time() - thread_start_time
      jax.monitoring.record_event_duration_secs(
          '/jax/checkpoint/write/async/thread_duration_sec', elapsed_secs)
    except Exception as e:
      # Handed over to the main thread through `check_for_errors`.
      self._exception = e

  def _start_async_commit(self, on_commit_callback):
    """Draws a fresh commit id and starts the background commit thread."""
    self._count = next(_module_unique_count)
    self._on_commit_callback = on_commit_callback

    worker = threading.Thread(target=self._thread_func)
    self._thread = worker
    worker.start()

  def check_for_errors(self):
    """Re-raises, once, an exception captured by the background thread."""
    exception = self._exception
    if exception is None:
      return
    # Clears self._exception so it is only raised once.
    self._exception = None
    raise exception  # pylint: disable=raising-bad-type

  def wait_until_finished(self):
    """Joins the background thread, surfaces its error, then awaits process 0."""
    worker = self._thread
    if worker is not None:
      worker.join()
      self._thread = None
      logger.info('Thread joined successfully')

    self.check_for_errors()
    logger.info('Error check finished successfully')

    if jax.process_count() <= 1 or self._count is None:
      return
    # Block until process 0 writes success value to the key value store.
    # If it fails to write it, then `blocking_key_value_get` will time out.
    get_key = _get_key(self._count)
    self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
    logger.info('blocking_key_value_get on key %s was successfully '
                'completed.', get_key)

  def _add_futures(self, futures: Sequence[asyncio.Future]):
    """Stores the futures the next background commit will wait on."""
    self._commit_futures = futures
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
  """Responsible for serializing GDAs via TensorStore."""

  def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
    """Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.

    TensorStore writes to a storage layer in 2 steps:

    * Reading/copying from the source, after which the source can be modified.
        * Returns a copy future.
    * Writing/committing to the storage layer.
        * Returns a commit future.

    In asynchronous mode, the commit futures are waited on in a separate
    thread, allowing other computation to proceed.

    Args:
      arrays: GlobalDeviceArrays or Arrays that should be serialized.
      tensorstore_specs: TensorStore specs that are used to serialize GDAs or
        Arrays.
      on_commit_callback: This callback will be executed after all processes
        have finished writing their checkpoints to disk. On filesystems where
        atomic rename operations are supported, you can rename from the
        temporary directory to the final directory. On GCS, you write to the
        final directory directly and in `on_commit_callback` you write a
        success file indicating that the serialization was successful, because
        GCS does not support atomic rename operations.
    """
    logger.info('Waiting for previous serialization to finish.')
    self.wait_until_finished()

    # One list per spec; `async_serialize` appends its futures to it.
    spec_count = len(tensorstore_specs)
    per_array_futures = [[] for _ in range(spec_count)]

    async def _serialize_all():
      pending_writes = jax.tree_util.tree_map(
          async_serialize, arrays, tensorstore_specs, per_array_futures)
      return await asyncio.gather(*pending_writes)

    asyncio.run(_serialize_all())

    flat_futures = jax.tree_util.tree_flatten(per_array_futures)[0]
    self._add_futures(flat_futures)

    # Starts the background thread that waits on the futures collected above.
    self._start_async_commit(on_commit_callback)

  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[Dict[str, Any]],
                  global_shapes: Optional[Sequence[array.Shape]] = None,
                  dtypes: Optional[Sequence[typing.DTypeLike]] = None,
                  concurrent_gb: int = 32):
    """Waits for any in-flight serialization, then deserializes the arrays.

    All arguments are forwarded unchanged to `run_deserialization`, whose
    result is returned.
    """
    self.wait_until_finished()
    restored = run_deserialization(shardings, tensorstore_specs,
                                   global_shapes, dtypes, concurrent_gb)
    return restored