# Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/array_serialization/serialization_test.py
#
# Repository viewer header: 318 lines, 11 KiB, Python (Raw / Normal View / History).
# Revision date: 2023-06-19 00:49:18 +02:00
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of GDA."""
import asyncio
import math
from functools import partial
import os
import pathlib
import tracemalloc as tm
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization
import numpy as np
import tensorstore as ts
# NOTE(review): presumably lets absl command-line flags feed the JAX config
# before any test runs -- confirm against the jax.config documentation.
config.parse_flags_with_absl()
# Assigned by setUpModule() and called by tearDownModule(); None until then.
prev_xla_flags = None
def setUpModule():
  """Ask for 8 host-platform devices before the tests in this module run.

  The value returned by `jtu.set_host_platform_device_count` is stored in the
  module-level `prev_xla_flags`, which `tearDownModule` later calls.
  """
  global prev_xla_flags
  # This governs the number of CPU devices. On TPU we always have 2 devices.
  restore_flags = jtu.set_host_platform_device_count(8)
  prev_xla_flags = restore_flags
# Put the previous configuration back so that any test modules executed after
# this one are not affected.
def tearDownModule():
  """Call the object that `setUpModule` saved in `prev_xla_flags`."""
  restore = prev_xla_flags
  restore()
class CheckpointTest(jtu.JaxTestCase):
  """Tests for `jax.experimental.array_serialization.serialization`.

  The tests write arrays with `GlobalAsyncCheckpointManager`, read them back
  with `run_deserialization` / `async_deserialize`, and check the helpers that
  build and inspect tensorstore specs.
  """

  def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
    """Rename `temp_ckpt_dir` to `final_ckpt_dir` once a checkpoint commits.

    The tests in this class pass the same directory for both arguments.

    Args:
      temp_ckpt_dir: Directory the checkpoint was written into.
      final_ckpt_dir: Directory name the checkpoint should end up with.
    """
    os.rename(temp_ckpt_dir, final_ckpt_dir)

  @jtu.skip_on_devices('cpu')
  def test_memory_consumption(self):
    """Compare peak traced memory with and without an in-flight byte limit."""
    global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y'))
    inp_shape = (2_048, 4_096)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)
    sharding = NamedSharding(global_mesh, pspec)
    # 2048 * 4096 ~= 8.4e6 int32 elements, i.e. 32 MiB for the whole array.
    src = jax.numpy.arange(num, dtype=np.int32).reshape(inp_shape)
    inp = array.make_array_from_callback(
        inp_shape, sharding,
        lambda idx: src[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path)
    tspec = serialization.get_tensorstore_spec(str(ckpt_dir))

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [inp], [tspec],
        on_commit_callback=partial(
            self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    deserialize_with_byte_limit = serialization.async_deserialize(
        sharding, tspec, inp_shape,
        byte_limiter=serialization._LimitInFlightBytes(4_200_000))
    tm.start()
    asyncio.run(deserialize_with_byte_limit).block_until_ready()
    _, peak = tm.get_traced_memory()
    # NB: some padding + tensorstore overhead. It should always be
    # less than array size (2048 * 4096 * 4 = 32M)
    self.assertLess(peak, 10_000_000)

    deserialize_wo_limit = serialization.async_deserialize(
        sharding, tspec, inp_shape)
    tm.clear_traces()
    # NB: calling block_until_ready() is important here and above, because
    # otherwise this leads to a race condition and a segfault, with
    # tensorstore attempting to dealloc using tracemalloc which is already
    # destroyed.
    asyncio.run(deserialize_wo_limit).block_until_ready()
    _, peak = tm.get_traced_memory()
    # We load the entire array in memory here.
    self.assertGreater(peak, 30_000_000)
    tm.stop()

  def test_checkpointing_jax_array(self):
    """Serialize three arrays and read them back under various shardings."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)

    # First Array
    global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
    a1 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data1[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
    # Every checkpoint location must be a directory, like the second and
    # third below; a regular file here cannot hold the checkpoint contents.
    ckpt_path1 = pathlib.Path(
        self.create_tempdir(f'{ckpt_dir}/first').full_path)

    # Second Array
    global_input_data2 = np.arange(
        num, num + num, dtype=np.int32).reshape(inp_shape)
    a2 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data2[idx])
    ckpt_path2 = pathlib.Path(
        self.create_tempdir(f'{ckpt_dir}/second').full_path)

    # Third Array: zero-sized, with the same spec on an 8-device 1D mesh.
    def cb3(_):
      return np.array([], dtype=np.float32)

    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
    a3 = array.make_array_from_callback(
        (0,), NamedSharding(global_mesh1d, P(None)), cb3)
    ckpt_path3 = pathlib.Path(
        self.create_tempdir(f'{ckpt_dir}/third').full_path)

    ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [a1, a2, a3], tspecs,
        on_commit_callback=partial(
            self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    # The second array is deliberately restored with a different
    # partitioning (P('x')) from the one it was written with.
    m1, m2, m3 = serialization.run_deserialization(
        [NamedSharding(global_mesh, pspec),
         NamedSharding(global_mesh, P('x')),
         NamedSharding(global_mesh1d, P(None))],
        tspecs)

    self.assertIsInstance(m1, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
                           np.array([[0], [2]], dtype=np.int32))
    self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
                           np.array([[1], [3]], dtype=np.int32))
    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    self.assertIsInstance(m2, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
                           np.array([[16, 17], [18, 19]], dtype=np.int32))
    self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
                           np.array([[16, 17], [18, 19]], dtype=np.int32))
    self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

    self.assertIsInstance(m3, array.ArrayImpl)
    for i, shard in enumerate(m3.addressable_shards):
      self.assertEqual(shard.index, (slice(None),))
      self.assertEqual(shard.replica_id, i)
      self.assertArraysEqual(np.asarray(shard.data),
                             np.array([], dtype=np.float32))
    self.assertEqual(m3.dtype, np.float32)

  def test_checkpointing_with_bigger_shape_jax_array(self):
    """Restore an (8, 2) checkpoint into a larger shape and as a replica."""
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data1 = np.arange(
        num, dtype=np.int32).reshape(global_input_shape)

    def cb1(index):
      return global_input_data1[index]

    arr = array.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb1)
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [arr], tspecs,
        on_commit_callback=partial(
            self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    # Read back as float32 with global shape (12, 2) on a 4x2 mesh; the
    # expectations below have zeros wherever the saved (8, 2) data ends.
    ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
    m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
                                            [np.float32])

    # Expected shard contents keyed by device id.
    expected_data = {
        0: np.array([[0], [2], [4]], dtype=np.float32),
        1: np.array([[1], [3], [5]], dtype=np.float32),
        2: np.array([[6], [8], [10]], dtype=np.float32),
        3: np.array([[7], [9], [11]], dtype=np.float32),
        4: np.array([[12], [14], [0]], dtype=np.float32),
        5: np.array([[13], [15], [0]], dtype=np.float32),
        6: np.array([[0], [0], [0]], dtype=np.float32),
        7: np.array([[0], [0], [0]], dtype=np.float32),
    }

    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data),
                             expected_data[shard.device.id])

    # Fully replicated restore: every shard holds the whole array.
    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)],
                                            [np.float32])
    for shard in m2.addressable_shards:
      self.assertArraysEqual(shard.data, global_input_data1.astype('float32'))

  def test_checkpointing_scalar_jax_array(self):
    """Round-trip a rank-0 array, reading it back as float32."""
    global_mesh = jtu.create_global_mesh((2,), ('x',))
    global_input_shape = ()
    data = np.array(4)
    s = NamedSharding(global_mesh, P(None))
    array1 = array.make_array_from_callback(
        global_input_shape, s, lambda idx: data[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [array1], tspecs,
        on_commit_callback=partial(
            self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    ds = NamedSharding(jtu.create_global_mesh((2,), ('x',)), P(None))

    m1, = serialization.run_deserialization(
        [ds],
        tspecs,
        [()],
        [np.float32]
    )

    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data), data.astype(np.float32))

  def test_deserialize_tensorstore_array_jax_array(self):
    """Deserialize from the spec of a `ts.array` built from a NumPy array."""
    global_mesh = jtu.create_global_mesh((2,), ('x',))
    data = np.arange(1024)
    tspec = ts.array(data).spec()
    m1, = serialization.run_deserialization(
        [NamedSharding(global_mesh, P(None))],
        [tspec]
    )
    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data), data)

  def test_spec_has_metadata(self):
    """`_spec_has_metadata` is true for nested and top-level 'metadata'."""
    # 'metadata' appears only inside the nested dict under 'e'.
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
            'metadata': 3
        },
        'f': 4
    }
    self.assertTrue(serialization._spec_has_metadata(spec))
    # 'metadata' appears as a top-level key.
    self.assertTrue(
        serialization._spec_has_metadata({
            'driver': 'zarr',
            'kvstore': 'gfile',
            'metadata': {
                'chunks': 4,
                'shape': (32, 64)
            },
            'one_more': 'thing'
        }))

  def test_spec_has_no_metadata(self):
    """`_spec_has_metadata` is false when no 'metadata' key exists."""
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
        },
        'f': 4
    }
    self.assertFalse(serialization._spec_has_metadata(spec))

  def test_empty_spec_has_no_metadata(self):
    """`_spec_has_metadata` is false for an empty spec."""
    spec = {}
    self.assertFalse(serialization._spec_has_metadata(spec))

  @parameterized.named_parameters(
      ('gcs', 'gs://my/ckpt/dir/path'),
      ('file', '/my/ckpt/dir/path')
  )
  def test_get_tensorstore_spec_ocdbt(self, path):
    """With ocdbt=True the kvstore is split into a base and a final 'path'.

    Args:
      path: Checkpoint location, either a `gs://` URL or an absolute local
        path; both parameterizations end in the component 'path'.
    """
    spec = serialization.get_tensorstore_spec(path, ocdbt=True)
    # GCS paths keep their parent as-is; local parents gain a file:// prefix.
    expected_base = os.path.dirname(path)
    if not path.startswith('gs://'):
      expected_base = f'file://{expected_base}'
    self.assertEqual(spec['kvstore']['base'], expected_base)
    self.assertEqual(spec['kvstore']['path'], 'path')

  def test_get_tensorstore_spec_not_absolute_path(self):
    """A relative checkpoint path is rejected with a ValueError."""
    path = 'my/ckpt/path'
    with self.assertRaisesRegex(ValueError,
                                "Checkpoint path should be absolute"):
      serialization.get_tensorstore_spec(path, ocdbt=True)
if __name__ == '__main__':
  # Hand the JAX-provided test loader to absltest when run as a script.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)