90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
"""Test kddcup99 loader, if the data is available,
or if specifically requested via environment variable
(e.g. for CI jobs).

Only 'percent10' mode is tested, as the full data
is too big to use in unit-testing.
"""
|
|
|
|
from functools import partial
|
|
|
|
import pytest
|
|
|
|
from sklearn.datasets.tests.test_common import (
|
|
check_as_frame,
|
|
check_pandas_dependency_message,
|
|
check_return_X_y,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("as_frame", [True, False])
@pytest.mark.parametrize(
    "subset, n_samples, n_features",
    [
        (None, 494021, 41),
        ("SA", 100655, 41),
        ("SF", 73237, 4),
        ("http", 58725, 3),
        ("smtp", 9571, 3),
    ],
)
def test_fetch_kddcup99_percent10(
    fetch_kddcup99_fxt, as_frame, subset, n_samples, n_features
):
    """Check the shapes and description returned for each parametrized subset.

    The fetcher fixture is called with the given ``subset`` and ``as_frame``
    and the result is compared against the expected sample/feature counts.
    """
    bunch = fetch_kddcup99_fxt(subset=subset, as_frame=as_frame)

    expected_data_shape = (n_samples, n_features)
    expected_target_shape = (n_samples,)
    assert bunch.data.shape == expected_data_shape
    assert bunch.target.shape == expected_target_shape

    if as_frame:
        # The frame is expected to have one more column than ``data``.
        expected_frame_shape = (n_samples, n_features + 1)
        assert bunch.frame.shape == expected_frame_shape

    assert bunch.DESCR.startswith(".. _kddcup99_dataset:")
|
|
|
|
|
|
def test_fetch_kddcup99_return_X_y(fetch_kddcup99_fxt):
    """Run the shared ``check_return_X_y`` helper on the "smtp" subset."""
    # Bind the subset once so the same callable serves both the reference
    # fetch and the helper's own calls.
    fetch_smtp = partial(fetch_kddcup99_fxt, subset="smtp")
    check_return_X_y(fetch_smtp(), fetch_smtp)
|
|
|
|
|
|
def test_fetch_kddcup99_as_frame(fetch_kddcup99_fxt):
    """Run the shared ``check_as_frame`` helper on a default fetch."""
    check_as_frame(fetch_kddcup99_fxt(), fetch_kddcup99_fxt)
|
|
|
|
|
|
def test_fetch_kddcup99_shuffle(fetch_kddcup99_fxt):
    """Check that ``shuffle=True`` keeps the target values and the shapes.

    The "SA" subset is fetched twice with the same ``random_state``, once
    without and once with ``shuffle=True``; the set of target values and the
    ``data``/``target`` shapes must agree between the two results.
    """
    # Arguments shared by both fetches; only ``shuffle`` differs.
    fetch_sa = partial(
        fetch_kddcup99_fxt, random_state=0, subset="SA", percent10=True
    )
    reference = fetch_sa()
    shuffled = fetch_sa(shuffle=True)

    reference_targets = set(reference["target"])
    shuffled_targets = set(shuffled["target"])
    assert reference_targets == shuffled_targets

    assert shuffled.data.shape == reference.data.shape
    assert shuffled.target.shape == reference.target.shape
|
|
|
|
|
|
def test_pandas_dependency_message(fetch_kddcup99_fxt, hide_available_pandas):
    """Run the shared pandas-dependency message check on the fetcher.

    ``hide_available_pandas`` is only requested as a fixture; the body does
    not reference it directly.
    """
    check_pandas_dependency_message(fetch_kddcup99_fxt)
|
|
|
|
|
|
def test_corrupted_file_error_message(fetch_kddcup99_fxt, tmp_path):
    """Check that a nice error message is raised when cache is corrupted.

    A fake cache directory containing an invalid ``samples`` file is created
    under ``tmp_path``; fetching with ``data_home`` pointing there must raise
    an ``OSError`` whose message names that directory.
    """
    import re

    kddcup99_dir = tmp_path / "kddcup99_10-py3"
    kddcup99_dir.mkdir()
    samples_path = kddcup99_dir / "samples"
    samples_path.write_bytes(b"THIS IS CORRUPTED")

    msg = (
        "The cache for fetch_kddcup99 is invalid, please "
        f"delete {kddcup99_dir} and run the fetch_kddcup99 again"
    )

    # ``match`` is interpreted as a regular expression, and the message embeds
    # a filesystem path: escape it so that backslashes (Windows) or other
    # metacharacters in the temporary path are matched literally.
    with pytest.raises(OSError, match=re.escape(msg)):
        fetch_kddcup99_fxt(data_home=str(tmp_path))
|