3RNN/Lib/site-packages/sklearn/datasets/tests/test_kddcup99.py
2024-05-26 19:49:15 +02:00

90 lines
2.5 KiB
Python

"""Test kddcup99 loader, if the data is available,
or if specifically requested via environment variable
(e.g. for CI jobs).
Only 'percent10' mode is tested, as the full data
is too big to use in unit-testing.
"""
import re
from functools import partial

import pytest

from sklearn.datasets.tests.test_common import (
    check_as_frame,
    check_pandas_dependency_message,
    check_return_X_y,
)
@pytest.mark.parametrize("as_frame", [True, False])
@pytest.mark.parametrize(
    "subset, n_samples, n_features",
    [
        (None, 494021, 41),
        ("SA", 100655, 41),
        ("SF", 73237, 4),
        ("http", 58725, 3),
        ("smtp", 9571, 3),
    ],
)
def test_fetch_kddcup99_percent10(
    fetch_kddcup99_fxt, as_frame, subset, n_samples, n_features
):
    """Check the shapes and description returned for each subset.

    Each parametrized subset is fetched with ``as_frame`` both on and off;
    the expected sample and feature counts come from the parametrization.
    """
    bunch = fetch_kddcup99_fxt(subset=subset, as_frame=as_frame)

    expected_data_shape = (n_samples, n_features)
    expected_target_shape = (n_samples,)
    assert bunch.data.shape == expected_data_shape
    assert bunch.target.shape == expected_target_shape

    if as_frame:
        # The frame is expected to carry one column more than the features
        # (presumably the target column -- confirm against the loader).
        assert bunch.frame.shape == (n_samples, n_features + 1)

    assert bunch.DESCR.startswith(".. _kddcup99_dataset:")
def test_fetch_kddcup99_return_X_y(fetch_kddcup99_fxt):
    """Run the shared ``return_X_y`` check with the subset fixed to "smtp"."""
    fetch_smtp = partial(fetch_kddcup99_fxt, subset="smtp")
    # The helper receives both an already-fetched result and the fetcher.
    check_return_X_y(fetch_smtp(), fetch_smtp)
def test_fetch_kddcup99_as_frame(fetch_kddcup99_fxt):
    """Run the shared ``as_frame`` check on a fetch with default arguments."""
    # The helper receives both a default-argument result and the fetcher.
    check_as_frame(fetch_kddcup99_fxt(), fetch_kddcup99_fxt)
def test_fetch_kddcup99_shuffle(fetch_kddcup99_fxt):
    """Check that ``shuffle=True`` keeps the label set and the shapes.

    The "SA" subset is fetched twice with identical settings, once without
    and once with ``shuffle=True``; the set of target values and the shapes
    of ``data`` and ``target`` must agree between the two results.
    """
    # Arguments shared by both fetches; only ``shuffle`` differs.
    common_kwargs = {"random_state": 0, "subset": "SA", "percent10": True}

    reference = fetch_kddcup99_fxt(**common_kwargs)
    shuffled = fetch_kddcup99_fxt(shuffle=True, **common_kwargs)

    assert set(reference["target"]) == set(shuffled["target"])
    assert shuffled.data.shape == reference.data.shape
    assert shuffled.target.shape == reference.target.shape
def test_pandas_dependency_message(fetch_kddcup99_fxt, hide_available_pandas):
    """Run the shared pandas-dependency message check on the loader.

    ``hide_available_pandas`` is requested only as a fixture and is never
    referenced in the body (presumably it makes pandas look unavailable --
    confirm against the fixture's definition).
    """
    check_pandas_dependency_message(fetch_kddcup99_fxt)
def test_corrupted_file_error_message(fetch_kddcup99_fxt, tmp_path):
    """Check that a nice error message is raised when cache is corrupted.

    A ``samples`` file with junk content is planted in the
    ``kddcup99_10-py3`` directory under ``tmp_path``, and the loader is then
    pointed at ``tmp_path`` as its data home. It must raise ``OSError`` with
    a message naming the directory to delete.
    """
    kddcup99_dir = tmp_path / "kddcup99_10-py3"
    kddcup99_dir.mkdir()
    # Plant content that cannot be a valid cache entry.
    (kddcup99_dir / "samples").write_bytes(b"THIS IS CORRUPTED")

    msg = (
        "The cache for fetch_kddcup99 is invalid, please "
        f"delete {kddcup99_dir} and run the fetch_kddcup99 again"
    )

    # ``match`` is interpreted as a regular expression. The path embedded in
    # the message can contain regex metacharacters (backslashes on Windows,
    # dots, ...), so escape the whole message to compare it literally.
    with pytest.raises(OSError, match=re.escape(msg)):
        fetch_kddcup99_fxt(data_home=str(tmp_path))