Traktor/myenv/Lib/site-packages/sklearn/datasets/tests/test_common.py

137 lines
4.3 KiB
Python
Raw Permalink Normal View History

2024-05-26 05:12:46 +02:00
"""Test loaders for common functionality."""
import inspect
import os
import numpy as np
import pytest
import sklearn.datasets
def is_pillow_installed():
    """Return True when the ``PIL`` package can be imported, else False."""
    try:
        import PIL  # noqa
    except ImportError:
        available = False
    else:
        available = True
    return available
# Extra pytest marks for individual fetchers. The outer key is the loader
# parameter being exercised, the inner key is the fetcher's function name;
# `_generate_func_supporting_param` appends the matching mark to the
# generated test parameters.
FETCH_PYTEST_MARKERS = {
    "return_X_y": {
        "fetch_20newsgroups": pytest.mark.xfail(
            reason="X is a list and does not have a shape argument"
        ),
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
        "fetch_lfw_people": pytest.mark.skipif(
            not is_pillow_installed(), reason="pillow is not installed"
        ),
    },
    "as_frame": {
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
    },
}
def check_pandas_dependency_message(fetch_func):
    """Check the error raised by ``fetch_func(as_frame=True)`` without pandas.

    The test is skipped when pandas is importable. Otherwise the call must
    raise an ``ImportError`` whose message names the function and states that
    ``as_frame=True`` requires pandas.
    """
    try:
        import pandas  # noqa
    except ImportError:
        pandas_available = False
    else:
        pandas_available = True

    if pandas_available:
        pytest.skip("This test requires pandas to not be installed")

    # pandas has to be imported lazily by the fetcher, and its absence must be
    # reported with an informative message.
    expected_msg = f"{fetch_func.__name__} with as_frame=True requires pandas"
    with pytest.raises(ImportError, match=expected_msg):
        fetch_func(as_frame=True)
def check_return_X_y(bunch, dataset_func):
    """Check that ``dataset_func(return_X_y=True)`` is consistent with ``bunch``.

    The call must return a tuple whose first two items have the same shapes
    as ``bunch.data`` and ``bunch.target`` respectively.
    """
    returned = dataset_func(return_X_y=True)
    assert isinstance(returned, tuple)
    for position, attribute in enumerate(("data", "target")):
        assert returned[position].shape == getattr(bunch, attribute).shape
def check_as_frame(
    bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
):
    """Check the output of ``dataset_func`` when called with ``as_frame=True``.

    The test is skipped when pandas cannot be imported.

    Parameters
    ----------
    bunch : object
        Reference result of the loader; its ``data`` and ``target`` shapes
        are compared against the frame-based output.
    dataset_func : callable
        Loader accepting the ``as_frame`` and ``return_X_y`` keyword
        arguments.
    expected_data_dtype : optional
        When not None, every dtype of the returned data frame must equal it.
    expected_target_dtype : optional
        When not None, the dtype(s) of the returned target must equal it.
    """
    pd = pytest.importorskip("pandas")
    frame_bunch = dataset_func(as_frame=True)
    assert hasattr(frame_bunch, "frame")
    assert isinstance(frame_bunch.frame, pd.DataFrame)
    assert isinstance(frame_bunch.data, pd.DataFrame)
    assert frame_bunch.data.shape == bunch.data.shape
    # A target with more than one dimension is expected as a DataFrame, a
    # one-dimensional target as a Series.
    if frame_bunch.target.ndim > 1:
        assert isinstance(frame_bunch.target, pd.DataFrame)
    else:
        assert isinstance(frame_bunch.target, pd.Series)
    assert frame_bunch.target.shape[0] == bunch.target.shape[0]
    if expected_data_dtype is not None:
        assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
    if expected_target_dtype is not None:
        assert np.all(frame_bunch.target.dtypes == expected_target_dtype)

    # Test for return_X_y and as_frame=True
    frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
    assert isinstance(frame_X, pd.DataFrame)
    if frame_y.ndim > 1:
        # The target itself must be checked here; the previous code
        # re-checked frame_X and left a 2-D frame_y unverified.
        assert isinstance(frame_y, pd.DataFrame)
    else:
        assert isinstance(frame_y, pd.Series)
def _skip_network_tests():
return os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1"
def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
    """Yield ``pytest.param(name, func)`` for dataset functions accepting ``param``.

    Every function of ``sklearn.datasets`` whose name starts with one of the
    prefixes in ``dataset_type`` and whose signature has a parameter called
    ``param`` is yielded. Each entry carries a ``skipif`` mark that is active
    for ``fetch*`` functions when network tests are skipped, plus any extra
    mark registered for it in ``FETCH_PYTEST_MARKERS``.
    """
    extra_marks = FETCH_PYTEST_MARKERS.get(param, {})
    for name, obj in inspect.getmembers(sklearn.datasets):
        if not inspect.isfunction(obj):
            continue
        if not any(name.startswith(prefix) for prefix in dataset_type):
            continue
        if param not in inspect.signature(obj).parameters:
            continue

        # check if we should skip if we don't have network support
        network_mark = pytest.mark.skipif(
            condition=name.startswith("fetch") and _skip_network_tests(),
            reason="Skip because fetcher requires internet network",
        )
        if name in extra_marks:
            marks = [network_mark, extra_marks[name]]
        else:
            marks = [network_mark]
        yield pytest.param(name, obj, marks=marks)
@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("return_X_y")
)
def test_common_check_return_X_y(name, dataset_func):
    """Run ``check_return_X_y`` on every loader accepting ``return_X_y``."""
    reference = dataset_func()
    check_return_X_y(bunch=reference, dataset_func=dataset_func)
@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("as_frame")
)
def test_common_check_as_frame(name, dataset_func):
    """Run ``check_as_frame`` on every loader accepting ``as_frame``."""
    reference = dataset_func()
    check_as_frame(bunch=reference, dataset_func=dataset_func)
@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("as_frame")
)
def test_common_check_pandas_dependency(name, dataset_func):
    """Run the missing-pandas error check on every loader accepting ``as_frame``."""
    check_pandas_dependency_message(fetch_func=dataset_func)