74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import re
import warnings
from itertools import chain

import pytest

from sklearn import config_context
from sklearn.utils._chunking import gen_even_slices, get_chunk_n_rows
from sklearn.utils._testing import assert_array_equal
|
|
|
|
|
|
def test_gen_even_slices():
    """Check that the slices from gen_even_slices cover all samples."""
    some_range = range(10)
    # Concatenate the sub-ranges selected by each generated slice; the result
    # must reproduce the original range exactly. The loop variable is named
    # `chunk` so that it does not shadow the builtin `slice`.
    joined_range = list(
        chain.from_iterable(some_range[chunk] for chunk in gen_even_slices(10, 3))
    )
    assert_array_equal(some_range, joined_range)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("row_bytes", "max_n_rows", "working_memory", "expected"),
    [
        (1024, None, 1, 1024),
        (1024, None, 0.99999999, 1023),
        (1023, None, 1, 1025),
        (1025, None, 1, 1023),
        (1024, None, 2, 2048),
        (1024, 7, 1, 7),
        (1024 * 1024, None, 1, 1),
    ],
)
def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected):
    """Check the value and type returned by get_chunk_n_rows.

    Each case is exercised twice: once with ``working_memory`` passed as an
    argument and once with it set through ``config_context``. In both runs any
    ``UserWarning`` is escalated to an error, so a warning fails the test.
    """
    shared_kwargs = {"row_bytes": row_bytes, "max_n_rows": max_n_rows}

    # working_memory given explicitly as an argument.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        n_rows = get_chunk_n_rows(working_memory=working_memory, **shared_kwargs)
    assert n_rows == expected
    assert type(n_rows) is type(expected)

    # working_memory taken from the active configuration instead.
    with config_context(working_memory=working_memory):
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            n_rows = get_chunk_n_rows(**shared_kwargs)
        assert n_rows == expected
        assert type(n_rows) is type(expected)
|
|
|
|
|
|
def test_get_chunk_n_rows_warns():
    """Check that warning is raised when working_memory is too low."""
    # One row is a single byte larger than 1 MiB while working_memory is 1.
    row_bytes = 1024 * 1024 + 1
    max_n_rows = None
    working_memory = 1
    expected = 1

    # `pytest.warns(match=...)` interprets its argument as a regular
    # expression, so the literal message is escaped: otherwise each "." would
    # match any character and the check would be looser than intended.
    warn_msg = re.escape(
        "Could not adhere to working_memory config. Currently 1MiB, 2MiB required."
    )

    # working_memory given explicitly as an argument.
    with pytest.warns(UserWarning, match=warn_msg):
        actual = get_chunk_n_rows(
            row_bytes=row_bytes,
            max_n_rows=max_n_rows,
            working_memory=working_memory,
        )

    assert actual == expected
    assert type(actual) is type(expected)

    # working_memory taken from the active configuration instead.
    with config_context(working_memory=working_memory):
        with pytest.warns(UserWarning, match=warn_msg):
            actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows)
        assert actual == expected
        assert type(actual) is type(expected)
|